diff --git a/test/test_api_extractors.py b/test/test_api_extractors.py new file mode 100644 index 00000000..cced61af --- /dev/null +++ b/test/test_api_extractors.py @@ -0,0 +1,342 @@ +""" +Tests for src/base_classes/api_extractors.py + +Covers ESPNFootballExtractor, ESPNBaseballExtractor, ESPNHockeyExtractor, +SoccerAPIExtractor, and the shared _extract_common_details logic. +""" + +import logging +import pytest +from src.base_classes.api_extractors import ( + ESPNFootballExtractor, + ESPNBaseballExtractor, + ESPNHockeyExtractor, + SoccerAPIExtractor, +) + + +# --------------------------------------------------------------------------- +# Shared test data factories +# --------------------------------------------------------------------------- + +def _make_espn_event(state: str = "in", home_abbr: str = "KC", away_abbr: str = "BUF", + home_score: str = "14", away_score: str = "7", + date_str: str = "2024-01-15T20:00:00Z", + include_situation: bool = False, + situation: dict | None = None, + status_detail: str = "2nd Qtr 8:42", + period: int = 2) -> dict: + """Build a minimal ESPN-style game event dict.""" + comp_status = { + "type": { + "state": state, + "shortDetail": status_detail, + "detail": status_detail, + "name": "STATUS_IN_PROGRESS", + }, + "period": period, + "displayClock": "8:42", + } + comp = { + "status": comp_status, + "competitors": [ + { + "homeAway": "home", + "team": {"abbreviation": home_abbr, "displayName": f"{home_abbr} Team"}, + "score": home_score, + }, + { + "homeAway": "away", + "team": {"abbreviation": away_abbr, "displayName": f"{away_abbr} Team"}, + "score": away_score, + }, + ], + } + if include_situation: + comp["situation"] = situation or {} + return { + "id": "test-game-1", + "date": date_str, + "competitions": [comp], + } + + +def _make_logger() -> logging.Logger: + return logging.getLogger("test_extractor") + + +# --------------------------------------------------------------------------- +# ESPNFootballExtractor +# --------------------------------------------------------------------------- + +class TestESPNFootballExtractor: + def setup_method(self): + self.extractor = ESPNFootballExtractor(_make_logger()) + + def test_extract_live_game_basic_fields(self): + event = _make_espn_event(state="in", home_score="14", away_score="7") + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["home_abbr"] == "KC" + assert result["away_abbr"] == "BUF" + assert result["home_score"] == "14" + assert result["away_score"] == "7" + assert result["is_live"] is True + assert result["is_final"] is False + assert result["is_upcoming"] is False + + def test_extract_final_game(self): + event = _make_espn_event(state="post") + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["is_final"] is True + assert result["is_live"] is False + + def test_extract_upcoming_game(self): + event = _make_espn_event(state="pre") + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["is_upcoming"] is True + + def test_sport_specific_fields_default_when_pregame(self): + event = _make_espn_event(state="pre") + fields = self.extractor.get_sport_specific_fields(event) + assert "down" in fields + assert "distance" in fields + assert "possession" in fields + assert "is_redzone" in fields + assert fields["is_redzone"] is False + + def test_sport_specific_fields_live_with_situation(self): + situation = { + "down": 3, + "distance": 7, + "possession": "KC", + "isRedZone": True, + "homeTimeouts": 2, + "awayTimeouts": 1, + } + event = _make_espn_event(state="in", include_situation=True, situation=situation) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["down"] == 3 + assert fields["distance"] == 7 + assert fields["is_redzone"] is True + assert fields["home_timeouts"] == 2 + assert fields["away_timeouts"] == 1 + + def test_scoring_event_detected(self): + # situation must be non-empty (truthy) for the live block to execute + situation = {"down": 1, "distance": 10} + event = _make_espn_event( + state="in", + include_situation=True, + situation=situation, + status_detail="touchdown scored", + ) + fields = self.extractor.get_sport_specific_fields(event) + assert "touchdown" in fields.get("scoring_event", "").lower() + + def test_returns_none_on_empty_event(self): + assert self.extractor.extract_game_details({}) is None + + def test_returns_none_when_teams_missing(self): + event = { + "id": "x", + "date": "2024-01-15T20:00:00Z", + "competitions": [ + { + "status": {"type": {"state": "in", "shortDetail": "", "detail": "", "name": ""}}, + "competitors": [], # no competitors + } + ], + } + assert self.extractor.extract_game_details(event) is None + + def test_date_z_suffix_parsed(self): + event = _make_espn_event(date_str="2024-01-15T20:00:00Z") + result = self.extractor.extract_game_details(event) + # Should not raise and should return a result + assert result is not None + + def test_id_propagated(self): + event = _make_espn_event() + result = self.extractor.extract_game_details(event) + assert result["id"] == "test-game-1" + + +# --------------------------------------------------------------------------- +# ESPNBaseballExtractor +# --------------------------------------------------------------------------- + +class TestESPNBaseballExtractor: + def setup_method(self): + self.extractor = ESPNBaseballExtractor(_make_logger()) + + def test_extract_live_game(self): + event = _make_espn_event( + state="in", home_abbr="NYY", away_abbr="BOS", + home_score="3", away_score="2" + ) + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["home_abbr"] == "NYY" + assert result["is_live"] is True + + def test_baseball_sport_fields_defaults(self): + event = _make_espn_event(state="pre") + fields = self.extractor.get_sport_specific_fields(event) + assert "inning" in fields + assert "outs" in fields + assert "bases" in fields + assert "strikes" in fields + assert "balls" in fields + + def test_baseball_sport_fields_live(self): + situation = { + "inning": 7, + "outs": 2, + "bases": "110", + "strikes": 2, + "balls": 3, + "pitcher": "Smith", + "batter": "Jones", + } + event = _make_espn_event(state="in", include_situation=True, situation=situation) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["inning"] == 7 + assert fields["outs"] == 2 + assert fields["strikes"] == 2 + assert fields["pitcher"] == "Smith" + + def test_returns_none_on_empty(self): + assert self.extractor.extract_game_details({}) is None + + +# --------------------------------------------------------------------------- +# ESPNHockeyExtractor +# --------------------------------------------------------------------------- + +class TestESPNHockeyExtractor: + def setup_method(self): + self.extractor = ESPNHockeyExtractor(_make_logger()) + + def test_extract_live_game(self): + event = _make_espn_event( + state="in", home_abbr="BOS", away_abbr="TOR", + home_score="2", away_score="1" + ) + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["is_live"] is True + + def test_hockey_period_text_p1(self): + situation = {"isPowerPlay": False} + event = _make_espn_event( + state="in", include_situation=True, situation=situation, period=1 + ) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["period_text"] == "P1" + + def test_hockey_period_text_p2(self): + situation = {"isPowerPlay": False} # non-empty so the live block executes + event = _make_espn_event( + state="in", include_situation=True, situation=situation, period=2 + ) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["period_text"] == "P2" + + def test_hockey_period_text_p3(self): + situation = {"isPowerPlay": False} + event = _make_espn_event( + state="in", include_situation=True, situation=situation, period=3 + ) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["period_text"] == "P3" + + def test_hockey_period_text_ot(self): + situation = {"isPowerPlay": False} + event = _make_espn_event( + state="in", include_situation=True, situation=situation, period=4 + ) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["period_text"] == "OT1" + + def test_hockey_power_play(self): + situation = {"isPowerPlay": True, "homeShots": 12, "awayShots": 8} + event = _make_espn_event(state="in", include_situation=True, situation=situation, period=2) + fields = self.extractor.get_sport_specific_fields(event) + assert fields["power_play"] is True + assert fields["shots_on_goal"]["home"] == 12 + assert fields["shots_on_goal"]["away"] == 8 + + def test_hockey_fields_defaults_pregame(self): + event = _make_espn_event(state="pre") + fields = self.extractor.get_sport_specific_fields(event) + assert "period" in fields + assert "power_play" in fields + assert fields["power_play"] is False + + def test_returns_none_on_empty(self): + assert self.extractor.extract_game_details({}) is None + + +# --------------------------------------------------------------------------- +# SoccerAPIExtractor +# --------------------------------------------------------------------------- + +class TestSoccerAPIExtractor: + def setup_method(self): + self.extractor = SoccerAPIExtractor(_make_logger()) + + def _make_soccer_event(self, is_live: bool = True) -> dict: + return { + "id": "soccer-1", + "home_team": {"abbreviation": "ARS", "name": "Arsenal"}, + "away_team": {"abbreviation": "CHE", "name": "Chelsea"}, + "home_score": "2", + "away_score": "1", + "status": "LIVE", + "is_live": is_live, + "is_final": not is_live, + "is_upcoming": False, + "half": "1", + "stoppage_time": "2", + "home_yellow_cards": 1, + "away_yellow_cards": 2, + "home_red_cards": 0, + "away_red_cards": 0, + "home_possession": 55, + "away_possession": 45, + } + + def test_extract_live_game(self): + event = self._make_soccer_event(is_live=True) + result = self.extractor.extract_game_details(event) + assert result is not None + assert result["home_abbr"] == "ARS" + assert result["away_abbr"] == "CHE" + assert result["is_live"] is True + + def test_sport_specific_cards(self): + event = self._make_soccer_event() + fields = self.extractor.get_sport_specific_fields(event) + assert fields["cards"]["home_yellow"] == 1 + assert fields["cards"]["away_yellow"] == 2 + assert fields["cards"]["home_red"] == 0 + + def test_sport_specific_possession(self): + event = self._make_soccer_event() + fields = self.extractor.get_sport_specific_fields(event) + assert fields["possession"]["home"] == 55 + assert fields["possession"]["away"] == 45 + + def test_sport_specific_half(self): + event = self._make_soccer_event() + fields = self.extractor.get_sport_specific_fields(event) + assert fields["half"] == "1" + + def test_scores_as_strings(self): + event = self._make_soccer_event() + result = self.extractor.extract_game_details(event) + assert result["home_score"] == "2" + assert result["away_score"] == "1" diff --git a/test/test_background_data_service.py b/test/test_background_data_service.py new file mode 100644 index 00000000..587cab30 --- /dev/null +++ b/test/test_background_data_service.py @@ -0,0 +1,299 @@ +""" +Tests for src/background_data_service.py + +Covers BackgroundDataService: submit_fetch_request, get_result, +is_request_complete, get_request_status, cancel_request, get_statistics, +_cleanup_completed_requests, shutdown, and get_background_service singleton. +""" + +import time +import pytest +from unittest.mock import MagicMock, patch, Mock +from concurrent.futures import Future + +from src.background_data_service import ( + BackgroundDataService, + FetchStatus, + FetchResult, + FetchRequest, + get_background_service, + shutdown_background_service, +) +import src.background_data_service as bds_module + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def reset_global_service(): + """Ensure each test starts with no global singleton.""" + shutdown_background_service() + yield + shutdown_background_service() + + +@pytest.fixture +def mock_cache_manager(): + m = MagicMock() + m.get.return_value = None + m.set.return_value = None + m.generate_sport_cache_key.return_value = "test_key" + return m + + +@pytest.fixture +def service(mock_cache_manager): + svc = BackgroundDataService(mock_cache_manager, max_workers=2, request_timeout=5) + yield svc + svc.shutdown(wait=False) + + +# --------------------------------------------------------------------------- +# Initialisation +# --------------------------------------------------------------------------- + +class TestInitialisation: + def test_stats_zeroed(self, service): + stats = service.get_statistics() + assert stats["total_requests"] == 0 + assert stats["completed_requests"] == 0 + assert stats["failed_requests"] == 0 + + def test_no_active_requests(self, service): + assert len(service.active_requests) == 0 + + def test_not_shutdown(self, service): + assert service._shutdown is False + + +# --------------------------------------------------------------------------- +# Cache hit path +# --------------------------------------------------------------------------- + +class TestCacheHit: + def test_cache_hit_returns_request_id(self, service, mock_cache_manager): + mock_cache_manager.get.return_value = {"events": [{"id": "1"}]} + req_id = service.submit_fetch_request( + sport="nfl", year=2024, + url="https://example.com/nfl", + cache_key="nfl_key", + ) + assert req_id is not None + # Request should be immediately complete due to cache hit + result = service.get_result(req_id) + assert result is not None + assert result.success is True + assert result.cached is True + + def test_cache_hit_increments_stat(self, service, mock_cache_manager): + mock_cache_manager.get.return_value = {"events": []} + service.submit_fetch_request(sport="nba", year=2024, url="https://x.com", cache_key="k") + stats = service.get_statistics() + assert stats["cached_hits"] == 1 + + +# --------------------------------------------------------------------------- +# Actual fetch path (mocked HTTP) +# --------------------------------------------------------------------------- + +class TestFetchPath: + def _valid_payload(self) -> dict: + return {"events": [{"id": "g1"}, {"id": "g2"}]} + + def test_successful_fetch_completes(self, service, mock_cache_manager): + mock_resp = Mock() + mock_resp.json.return_value = self._valid_payload() + mock_resp.raise_for_status.return_value = None + + with patch.object(service.session, "get", return_value=mock_resp): + req_id = service.submit_fetch_request( + sport="nfl", year=2024, + url="https://example.com/nfl", + cache_key="nfl_test", + ) + # Wait for the background thread + deadline = time.time() + 5 + while not service.is_request_complete(req_id) and time.time() < deadline: + time.sleep(0.05) + + result = service.get_result(req_id) + assert result is not None + assert result.success is True + assert result.data == self._valid_payload() + + def test_failed_fetch_records_error(self, service, mock_cache_manager): + with patch.object(service.session, "get", side_effect=Exception("network error")): + req_id = service.submit_fetch_request( + sport="nba", year=2024, + url="https://example.com/nba", + cache_key="nba_test", + max_retries=0, + ) + deadline = time.time() + 5 + while not service.is_request_complete(req_id) and time.time() < deadline: + time.sleep(0.05) + + result = service.get_result(req_id) + assert result is not None + assert result.success is False + assert result.error is not None + + def test_cache_miss_increments_stat(self, service, mock_cache_manager): + mock_resp = Mock() + mock_resp.json.return_value = self._valid_payload() + mock_resp.raise_for_status.return_value = None + + with patch.object(service.session, "get", return_value=mock_resp): + service.submit_fetch_request( + sport="nfl", year=2024, url="https://x.com", cache_key="new_key", + ) + stats = service.get_statistics() + assert stats["cache_misses"] == 1 + + def test_callback_called_on_success(self, service, mock_cache_manager): + callback = Mock() + mock_resp = Mock() + mock_resp.json.return_value = self._valid_payload() + mock_resp.raise_for_status.return_value = None + + with patch.object(service.session, "get", return_value=mock_resp): + req_id = service.submit_fetch_request( + sport="nfl", year=2024, url="https://x.com", + cache_key="cb_key", callback=callback, max_retries=0, + ) + deadline = time.time() + 5 + while not service.is_request_complete(req_id) and time.time() < deadline: + time.sleep(0.05) + + callback.assert_called_once() + call_arg = callback.call_args[0][0] + assert isinstance(call_arg, FetchResult) + + def test_data_cached_after_successful_fetch(self, service, mock_cache_manager): + mock_resp = Mock() + mock_resp.json.return_value = self._valid_payload() + mock_resp.raise_for_status.return_value = None + + with patch.object(service.session, "get", return_value=mock_resp): + req_id = service.submit_fetch_request( + sport="nfl", year=2024, url="https://x.com", cache_key="cache_after_key", + ) + deadline = time.time() + 5 + while not service.is_request_complete(req_id) and time.time() < deadline: + time.sleep(0.05) + + mock_cache_manager.set.assert_called() + + +# --------------------------------------------------------------------------- +# Request status / cancel +# --------------------------------------------------------------------------- + +class TestRequestStatusAndCancel: + def test_unknown_request_status_is_none(self, service): + assert service.get_request_status("nonexistent") is None + + def test_cancel_active_request(self, service, mock_cache_manager): + # Manually insert an active request + req = FetchRequest( + id="r1", sport="nfl", year=2024, + cache_key="k", url="https://x.com", + ) + req.status = FetchStatus.PENDING + service.active_requests["r1"] = req + result = service.cancel_request("r1") + assert result is True + assert "r1" not in service.active_requests + + def test_cancel_nonexistent_request(self, service): + assert service.cancel_request("does-not-exist") is False + + def test_is_request_complete_false_for_active(self, service, mock_cache_manager): + req = FetchRequest( + id="r2", sport="mlb", year=2024, + cache_key="k2", url="https://x.com", + ) + service.active_requests["r2"] = req + assert service.is_request_complete("r2") is False + + def test_is_request_complete_true_for_done(self, service): + result = FetchResult(request_id="r3", success=True) + service.completed_requests["r3"] = result + assert service.is_request_complete("r3") is True + + def test_get_result_returns_none_for_unknown(self, service): + assert service.get_result("unknown") is None + + +# --------------------------------------------------------------------------- +# Shutdown +# --------------------------------------------------------------------------- + +class TestShutdown: + def test_shutdown_sets_flag(self, service): + service.shutdown(wait=False) + assert service._shutdown is True + + def test_submit_after_shutdown_raises(self, service, mock_cache_manager): + service.shutdown(wait=False) + with pytest.raises(RuntimeError, match="shutting down"): + service.submit_fetch_request( + sport="nfl", year=2024, url="https://x.com", cache_key="k" + ) + + +# --------------------------------------------------------------------------- +# Cleanup +# --------------------------------------------------------------------------- + +class TestCleanup: + def test_cleanup_removes_old_requests(self, service): + old_result = FetchResult(request_id="old", success=True) + old_result.completed_at = time.time() - 7200 # 2 hours ago + service.completed_requests["old"] = old_result + service._last_completed_requests_cleanup = 0 # force cleanup + removed = service._cleanup_completed_requests(force=True) + assert removed >= 1 + assert "old" not in service.completed_requests + + def test_cleanup_respects_interval(self, service): + old_result = FetchResult(request_id="r", success=True) + old_result.completed_at = time.time() - 7200 + service.completed_requests["r"] = old_result + # Cleanup interval not passed, should skip + service._last_completed_requests_cleanup = time.time() + removed = service._cleanup_completed_requests(force=False) + assert removed == 0 + + def test_size_limit_enforcement(self, service): + service._max_completed_requests = 3 + for i in range(5): + result = FetchResult(request_id=str(i), success=True) + result.completed_at = time.time() - (5 - i) * 100 # oldest first + service.completed_requests[str(i)] = result + service._last_completed_requests_cleanup = 0 + service._cleanup_completed_requests(force=True) + assert len(service.completed_requests) <= 3 + + +# --------------------------------------------------------------------------- +# Singleton get_background_service +# --------------------------------------------------------------------------- + +class TestGetBackgroundService: + def test_first_call_requires_cache_manager(self): + with pytest.raises(ValueError, match="cache_manager is required"): + get_background_service() + + def test_creates_singleton(self, mock_cache_manager): + svc1 = get_background_service(mock_cache_manager) + svc2 = get_background_service() + assert svc1 is svc2 + + def test_shutdown_clears_singleton(self, mock_cache_manager): + get_background_service(mock_cache_manager) + shutdown_background_service() + with pytest.raises(ValueError): + get_background_service() diff --git a/test/test_data_sources.py b/test/test_data_sources.py new file mode 100644 index 00000000..9aad0588 --- /dev/null +++ b/test/test_data_sources.py @@ -0,0 +1,209 @@ +""" +Tests for src/base_classes/data_sources.py + +Covers ESPNDataSource, MLBAPIDataSource, SoccerAPIDataSource. +All HTTP calls are mocked to avoid network access. +""" + +import logging +from datetime import datetime, date +from unittest.mock import MagicMock, patch, Mock +import pytest +import requests + +from src.base_classes.data_sources import ESPNDataSource, MLBAPIDataSource, SoccerAPIDataSource + + +def _make_logger() -> logging.Logger: + return logging.getLogger("test_data_sources") + + +def _mock_response(json_data: dict, status_code: int = 200): + resp = Mock(spec=requests.Response) + resp.status_code = status_code + resp.json.return_value = json_data + resp.raise_for_status = Mock() + if status_code >= 400: + resp.raise_for_status.side_effect = requests.HTTPError(response=resp) + return resp + + +# --------------------------------------------------------------------------- +# ESPNDataSource +# --------------------------------------------------------------------------- + +class TestESPNDataSource: + def setup_method(self): + self.source = ESPNDataSource(_make_logger()) + + def test_get_headers(self): + headers = self.source.get_headers() + assert headers["Accept"] == "application/json" + assert "LEDMatrix" in headers["User-Agent"] + + def test_fetch_live_games_returns_live_events(self): + live_event = { + "competitions": [{"status": {"type": {"state": "in"}}}] + } + non_live_event = { + "competitions": [{"status": {"type": {"state": "pre"}}}] + } + payload = {"events": [live_event, non_live_event]} + + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_live_games("football", "nfl") + + assert len(result) == 1 + assert result[0] is live_event + + def test_fetch_live_games_empty_when_none_live(self): + payload = {"events": [ + {"competitions": [{"status": {"type": {"state": "post"}}}]} + ]} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_live_games("football", "nfl") + assert result == [] + + def test_fetch_live_games_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("network failure")): + result = self.source.fetch_live_games("football", "nfl") + assert result == [] + + def test_fetch_schedule_returns_all_events(self): + events = [{"id": "1"}, {"id": "2"}] + payload = {"events": events} + start = datetime(2024, 1, 1) + end = datetime(2024, 1, 7) + + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_schedule("football", "nfl", (start, end)) + + assert len(result) == 2 + + def test_fetch_schedule_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("timeout")): + result = self.source.fetch_schedule("football", "nfl", (datetime.now(), datetime.now())) + assert result == [] + + def test_fetch_standings_success(self): + payload = {"standings": []} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_standings("football", "nfl") + assert result == payload + + def test_fetch_standings_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("error")): + result = self.source.fetch_standings("football", "nfl") + assert result == {} + + def test_base_url_set_correctly(self): + assert "espn.com" in self.source.base_url + + +# --------------------------------------------------------------------------- +# MLBAPIDataSource +# --------------------------------------------------------------------------- + +class TestMLBAPIDataSource: + def setup_method(self): + self.source = MLBAPIDataSource(_make_logger()) + + def test_fetch_live_games_filters_live(self): + live_game = {"status": {"abstractGameState": "Live"}} + final_game = {"status": {"abstractGameState": "Final"}} + payload = {"dates": [{"games": [live_game, final_game]}]} + + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_live_games("baseball", "mlb") + + assert len(result) == 1 + assert result[0] is live_game + + def test_fetch_live_games_empty_dates(self): + payload = {"dates": []} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_live_games("baseball", "mlb") + assert result == [] + + def test_fetch_live_games_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_live_games("baseball", "mlb") + assert result == [] + + def test_fetch_schedule_aggregates_all_dates(self): + payload = { + "dates": [ + {"games": [{"id": "1"}, {"id": "2"}]}, + {"games": [{"id": "3"}]}, + ] + } + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_schedule("baseball", "mlb", (datetime.now(), datetime.now())) + assert len(result) == 3 + + def test_fetch_schedule_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_schedule("baseball", "mlb", (datetime.now(), datetime.now())) + assert result == [] + + def test_fetch_standings_success(self): + payload = {"records": []} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_standings("baseball", "mlb") + assert result == payload + + def test_fetch_standings_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_standings("baseball", "mlb") + assert result == {} + + +# --------------------------------------------------------------------------- +# SoccerAPIDataSource +# --------------------------------------------------------------------------- + +class TestSoccerAPIDataSource: + def setup_method(self): + self.source = SoccerAPIDataSource(_make_logger(), api_key="test-key-123") + + def test_headers_include_api_key(self): + headers = self.source.get_headers() + assert headers["X-Auth-Token"] == "test-key-123" + + def test_headers_without_api_key(self): + source = SoccerAPIDataSource(_make_logger()) + headers = source.get_headers() + assert "X-Auth-Token" not in headers + + def test_fetch_live_games_success(self): + payload = {"matches": [{"id": "m1"}, {"id": "m2"}]} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_live_games("soccer", "eng.1") + assert len(result) == 2 + + def test_fetch_live_games_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_live_games("soccer", "eng.1") + assert result == [] + + def test_fetch_schedule_success(self): + payload = {"matches": [{"id": "m1"}]} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_schedule("soccer", "eng.1", (datetime.now(), datetime.now())) + assert len(result) == 1 + + def test_fetch_schedule_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_schedule("soccer", "eng.1", (datetime.now(), datetime.now())) + assert result == [] + + def test_fetch_standings_success(self): + payload = {"standings": []} + with patch.object(self.source.session, "get", return_value=_mock_response(payload)): + result = self.source.fetch_standings("soccer", "PL") + assert result == payload + + def test_fetch_standings_returns_empty_on_error(self): + with patch.object(self.source.session, "get", side_effect=Exception("err")): + result = self.source.fetch_standings("soccer", "PL") + assert result == {} diff --git a/test/test_game_helper.py b/test/test_game_helper.py new file mode 100644 index 00000000..9bea5b4c --- /dev/null +++ b/test/test_game_helper.py @@ -0,0 +1,317 @@ +""" +Tests for src/common/game_helper.py + +Covers GameHelper: extract_game_details, filter_*, sort_games_by_time, +process_games, get_game_summary, and all private helpers. +""" + +import logging +import pytest +from datetime import datetime, timezone, timedelta + +from src.common.game_helper import GameHelper + + +def _make_logger() -> logging.Logger: + return logging.getLogger("test_game_helper") + + +def _make_espn_event( + state: str = "in", + home_abbr: str = "LAL", + away_abbr: str = "BOS", + home_score: str = "105", + away_score: str = "98", + date_str: str = "2024-01-15T20:00:00Z", + period: int = 4, + status_name: str = "STATUS_IN_PROGRESS", + home_record: str = "30-10", + away_record: str = "25-15", + event_id: str = "game-1", +) -> dict: + return { + "id": event_id, + "date": date_str, + "competitions": [ + { + "status": { + "type": { + "state": state, + "shortDetail": "Q4 2:30", + "name": status_name, + }, + "period": period, + "displayClock": "2:30", + }, + "competitors": [ + { + "homeAway": "home", + "id": "h1", + "team": {"abbreviation": home_abbr, "displayName": f"{home_abbr} Team"}, + "score": home_score, + "records": [{"summary": home_record}], + }, + { + "homeAway": "away", + "id": "a1", + "team": {"abbreviation": away_abbr, "displayName": f"{away_abbr} Team"}, + "score": away_score, + "records": [{"summary": away_record}], + }, + ], + } + ], + } + + +@pytest.fixture +def helper(): + return GameHelper(timezone_str="UTC", logger=_make_logger()) + + +# --------------------------------------------------------------------------- +# extract_game_details +# --------------------------------------------------------------------------- + +class TestExtractGameDetails: + def test_live_game(self, helper): + event = _make_espn_event(state="in") + result = helper.extract_game_details(event) + assert result is not None + assert result["is_live"] is True + assert result["is_final"] is False + assert result["is_upcoming"] is False + + def test_final_game(self, helper): + event = _make_espn_event(state="post") + result = helper.extract_game_details(event) + assert result["is_final"] is True + + def test_upcoming_game(self, helper): + event = _make_espn_event(state="pre") + result = helper.extract_game_details(event) + assert result["is_upcoming"] is True + + def test_halftime_detection(self, helper): + event = _make_espn_event(state="halftime", status_name="STATUS_HALFTIME") + result = helper.extract_game_details(event) + assert result["is_halftime"] is True + + def test_basic_fields_present(self, helper): + event = _make_espn_event() + result = helper.extract_game_details(event) + for key in ("id", "home_abbr", "away_abbr", "home_score", "away_score", + "home_record", "away_record", "start_time_utc"): + assert key in result + + def test_team_abbreviations(self, helper): + event = _make_espn_event(home_abbr="MIA", away_abbr="PHX") + result = helper.extract_game_details(event) + assert result["home_abbr"] == "MIA" + assert result["away_abbr"] == "PHX" + + def test_scores_as_strings(self, helper): + event = _make_espn_event(home_score="110", away_score="99") + result = helper.extract_game_details(event) + assert result["home_score"] == "110" + assert result["away_score"] == "99" + + def test_returns_none_on_empty(self, helper): + assert helper.extract_game_details({}) is None + assert helper.extract_game_details(None) is None + + def test_returns_none_when_no_competitors(self, helper): + event = _make_espn_event() + event["competitions"][0]["competitors"] = [] + assert helper.extract_game_details(event) is None + + def test_date_z_suffix_parsed(self, helper): + event = _make_espn_event(date_str="2024-06-01T19:30:00Z") + result = helper.extract_game_details(event) + assert result["start_time_utc"] is not None + assert result["start_time_utc"].tzinfo is not None + + def test_zero_zero_record_suppressed(self, helper): + event = _make_espn_event(home_record="0-0", away_record="0-0-0") + result = helper.extract_game_details(event) + assert result["home_record"] == "" + assert result["away_record"] == "" + + def test_basketball_sport_fields(self, helper): + event = _make_espn_event(period=3) + result = helper.extract_game_details(event, sport="basketball") + assert result["period_text"] == "Q3" + assert "clock" in result + + def test_basketball_overtime_period(self, helper): + event = _make_espn_event(period=5) + result = helper.extract_game_details(event, sport="basketball") + assert result["period_text"] == "OT1" + + def test_football_sport_fields(self, helper): + event = _make_espn_event(period=2) + result = helper.extract_game_details(event, sport="football") + assert result["period_text"] == "Q2" + + def test_hockey_sport_fields_period_1(self, helper): + event = _make_espn_event(period=1) + result = helper.extract_game_details(event, sport="hockey") + assert result["period_text"] == "P1" + + def test_hockey_sport_fields_ot(self, helper): + event = _make_espn_event(period=4) + result = helper.extract_game_details(event, sport="hockey") + assert result["period_text"] == "OT1" + + def test_baseball_sport_fields(self, helper): + event = _make_espn_event(period=7) + result = helper.extract_game_details(event, sport="baseball") + assert result["period_text"] == "INN 7" + + +# --------------------------------------------------------------------------- +# Filter methods +# --------------------------------------------------------------------------- + +class TestFilterMethods: + def _make_games(self): + now = datetime.now(timezone.utc) + return [ + {"is_live": True, "is_final": False, "is_upcoming": False, "home_abbr": "LAL", "away_abbr": "BOS", "start_time_utc": now}, + {"is_live": False, "is_final": True, "is_upcoming": False, "home_abbr": "MIA", "away_abbr": "PHX", "start_time_utc": now - timedelta(hours=3)}, + {"is_live": False, "is_final": False, "is_upcoming": True, "home_abbr": "DAL", "away_abbr": "CHI", "start_time_utc": now + timedelta(hours=2)}, + ] + + def test_filter_live_games(self, helper): + games = self._make_games() + result = helper.filter_live_games(games) + assert len(result) == 1 + assert result[0]["home_abbr"] == "LAL" + + def test_filter_final_games(self, helper): + games = self._make_games() + result = helper.filter_final_games(games) + assert len(result) == 1 + assert result[0]["home_abbr"] == "MIA" + + def test_filter_upcoming_games(self, helper): + games = self._make_games() + result = helper.filter_upcoming_games(games) + assert len(result) == 1 + assert result[0]["home_abbr"] == "DAL" + + def test_filter_favorite_teams_match(self, helper): + games = self._make_games() + result = helper.filter_favorite_teams(games, ["LAL"]) + assert len(result) == 1 + assert result[0]["home_abbr"] == "LAL" + + def test_filter_favorite_teams_empty_list_returns_all(self, helper): + games = self._make_games() + result = helper.filter_favorite_teams(games, []) + assert len(result) == 3 + + def test_filter_favorite_teams_away_match(self, helper): + games = self._make_games() + result = helper.filter_favorite_teams(games, ["BOS"]) + assert len(result) == 1 + + def test_filter_recent_games_within_window(self, helper): + now = datetime.now(timezone.utc) + games = [ + {"start_time_utc": now - timedelta(days=2), "is_final": True}, + {"start_time_utc": now - timedelta(days=10), "is_final": True}, + ] + result = helper.filter_recent_games(games, days_back=7) + assert len(result) == 1 + + def test_filter_recent_games_all_within(self, helper): + now = datetime.now(timezone.utc) + games = [ + {"start_time_utc": now - timedelta(days=1)}, + {"start_time_utc": now - timedelta(days=3)}, + ] + result = helper.filter_recent_games(games, days_back=7) + assert len(result) == 2 + + def test_sort_games_ascending(self, helper): + now = datetime.now(timezone.utc) + games = [ + {"start_time_utc": now + timedelta(hours=2), "id": "late"}, + {"start_time_utc": now + timedelta(hours=1), "id": "early"}, + ] + result = helper.sort_games_by_time(games) + assert result[0]["id"] == "early" + + def test_sort_games_descending(self, helper): + now = datetime.now(timezone.utc) + games = [ + {"start_time_utc": now + timedelta(hours=1), "id": "early"}, + {"start_time_utc": now + timedelta(hours=2), "id": "late"}, + ] + result = helper.sort_games_by_time(games, reverse=True) + assert result[0]["id"] == "late" + + +# --------------------------------------------------------------------------- +# process_games +# --------------------------------------------------------------------------- + +class TestProcessGames: + def test_processes_valid_events(self, helper): + events = [ + _make_espn_event(event_id="1"), + _make_espn_event(event_id="2"), + ] + result = helper.process_games(events) + assert len(result) == 2 + + def test_skips_invalid_events(self, helper): + events = [ + _make_espn_event(event_id="1"), + {}, # invalid + ] + result = helper.process_games(events) + assert len(result) == 1 + + def test_empty_events(self, helper): + assert helper.process_games([]) == [] + + +# --------------------------------------------------------------------------- +# get_game_summary +# --------------------------------------------------------------------------- + +class TestGetGameSummary: + def test_live_summary(self, helper): + game = { + "home_abbr": "LAL", "away_abbr": "BOS", + "home_score": "105", "away_score": "98", + "status_text": "Q4 2:30", + "is_live": True, "is_final": False, + } + summary = helper.get_game_summary(game) + assert "BOS" in summary + assert "LAL" in summary + assert "98" in summary + assert "105" in summary + + def test_final_summary(self, helper): + game = { + "home_abbr": "LAL", "away_abbr": "BOS", + "home_score": "110", "away_score": "102", + "status_text": "Final", + "is_live": False, "is_final": True, + } + summary = helper.get_game_summary(game) + assert "Final" in summary + + def test_upcoming_summary(self, helper): + game = { + "home_abbr": "LAL", "away_abbr": "BOS", + "home_score": "0", "away_score": "0", + "status_text": "7:30 PM", + "is_live": False, "is_final": False, + } + summary = helper.get_game_summary(game) + assert "7:30 PM" in summary diff --git a/test/test_health_monitor.py b/test/test_health_monitor.py new file mode 100644 index 00000000..f8427476 --- /dev/null +++ b/test/test_health_monitor.py @@ -0,0 +1,303 @@ +""" +Tests for src/plugin_system/health_monitor.py + +Covers PluginHealthMonitor: get_plugin_health_status, get_plugin_health_metrics, +get_all_plugin_health, _get_recovery_suggestions, start/stop_monitoring, +register_health_check. +""" + +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime + +from src.plugin_system.health_monitor import ( + PluginHealthMonitor, + HealthStatus, + HealthMetrics, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_health_tracker( + summary: dict | None = None, + all_summaries: dict | None = None, +): + """Return a mock PluginHealthTracker.""" + tracker = MagicMock() + tracker.get_health_summary.return_value = summary + tracker.get_all_health_summaries.return_value = all_summaries or {} + return tracker + + +def _healthy_summary() -> dict: + return { + "success_rate": 100.0, + "circuit_state": "closed", + "consecutive_failures": 0, + "total_failures": 0, + "total_successes": 50, + "last_success_time": datetime.now().isoformat(), + "last_error": None, + } + + +def _degraded_summary() -> dict: + return { + "success_rate": 40.0, # 60% error rate + "circuit_state": "closed", + "consecutive_failures": 3, + "total_failures": 6, + "total_successes": 4, + "last_success_time": None, + "last_error": "timeout occurred", + } + + +def _unhealthy_summary() -> dict: + return { + "success_rate": 10.0, # 90% error rate + "circuit_state": "open", + "consecutive_failures": 10, + "total_failures": 9, + "total_successes": 1, + "last_success_time": None, + "last_error": "ImportError: missing module", + } + + +@pytest.fixture +def monitor(): + tracker = _make_health_tracker(_healthy_summary()) + return PluginHealthMonitor(health_tracker=tracker) + + +# --------------------------------------------------------------------------- +# get_plugin_health_status +# --------------------------------------------------------------------------- + +class TestGetPluginHealthStatus: + def test_healthy_status(self): + tracker = _make_health_tracker(_healthy_summary()) + monitor = PluginHealthMonitor(tracker) + status = monitor.get_plugin_health_status("plugin_a") + assert status == HealthStatus.HEALTHY + + def test_degraded_status(self): + tracker = _make_health_tracker(_degraded_summary()) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + status = monitor.get_plugin_health_status("plugin_b") + assert status == HealthStatus.DEGRADED + + def test_unhealthy_status(self): + tracker = _make_health_tracker(_unhealthy_summary()) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + status = monitor.get_plugin_health_status("plugin_c") + assert status == HealthStatus.UNHEALTHY + + def test_open_circuit_breaker_is_unhealthy(self): + summary = _healthy_summary() + summary["circuit_state"] = "open" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker) + status = monitor.get_plugin_health_status("plugin_d") + assert status == HealthStatus.UNHEALTHY + + def test_unknown_when_no_tracker(self): + monitor = PluginHealthMonitor(health_tracker=None) + status = monitor.get_plugin_health_status("plugin_e") + assert status == HealthStatus.UNKNOWN + + def test_unknown_when_no_summary(self): + tracker = _make_health_tracker(None) + monitor = PluginHealthMonitor(tracker) + status = monitor.get_plugin_health_status("plugin_f") + assert status == HealthStatus.UNKNOWN + + +# --------------------------------------------------------------------------- +# get_plugin_health_metrics +# --------------------------------------------------------------------------- + +class TestGetPluginHealthMetrics: + def test_healthy_metrics(self): + tracker = _make_health_tracker(_healthy_summary()) + monitor = PluginHealthMonitor(tracker) + metrics = monitor.get_plugin_health_metrics("plugin_a") + assert isinstance(metrics, HealthMetrics) + assert metrics.status == HealthStatus.HEALTHY + assert metrics.success_rate == pytest.approx(1.0) + assert metrics.error_rate == pytest.approx(0.0) + + def test_degraded_metrics(self): + tracker = _make_health_tracker(_degraded_summary()) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + metrics = monitor.get_plugin_health_metrics("plugin_b") + assert metrics.status == HealthStatus.DEGRADED + assert metrics.consecutive_failures == 3 + + def test_unhealthy_metrics(self): + tracker = _make_health_tracker(_unhealthy_summary()) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + metrics = monitor.get_plugin_health_metrics("plugin_c") + assert metrics.status == HealthStatus.UNHEALTHY + assert metrics.circuit_breaker_state == "open" + assert metrics.last_error is not None + + def test_metrics_without_tracker(self): + monitor = PluginHealthMonitor(health_tracker=None) + metrics = monitor.get_plugin_health_metrics("plugin_d") + assert metrics.status == HealthStatus.UNKNOWN + assert metrics.plugin_id == "plugin_d" + + def test_metrics_without_summary(self): + tracker = _make_health_tracker(None) + monitor = PluginHealthMonitor(tracker) + metrics = monitor.get_plugin_health_metrics("plugin_e") + assert metrics.status == HealthStatus.UNKNOWN + + def test_last_successful_update_parsed(self): + summary = _healthy_summary() + summary["last_success_time"] = "2024-06-01T12:00:00" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker) + metrics = monitor.get_plugin_health_metrics("plugin_a") + assert metrics.last_successful_update is not None + assert isinstance(metrics.last_successful_update, datetime) + + def test_invalid_last_success_time_handled(self): + summary = _healthy_summary() + summary["last_success_time"] = "not-a-date" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker) + # Should not raise + metrics = monitor.get_plugin_health_metrics("plugin_a") + assert metrics.last_successful_update is None + + def test_total_successes_failures(self): + tracker = _make_health_tracker(_degraded_summary()) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + metrics = monitor.get_plugin_health_metrics("plugin_b") + assert metrics.total_failures == 6 + assert metrics.total_successes == 4 + + +# --------------------------------------------------------------------------- +# get_all_plugin_health +# --------------------------------------------------------------------------- + +class TestGetAllPluginHealth: + def test_returns_empty_without_tracker(self): + monitor = PluginHealthMonitor(health_tracker=None) + result = monitor.get_all_plugin_health() + assert result == {} + + def test_returns_metrics_for_each_plugin(self): + all_summaries = { + "plugin_a": _healthy_summary(), + "plugin_b": _degraded_summary(), + } + tracker = MagicMock() + tracker.get_all_health_summaries.return_value = all_summaries + tracker.get_health_summary.side_effect = lambda pid: all_summaries.get(pid) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + result = monitor.get_all_plugin_health() + assert "plugin_a" in result + assert "plugin_b" in result + assert isinstance(result["plugin_a"], HealthMetrics) + + def test_returns_empty_when_no_summaries(self): + tracker = _make_health_tracker(all_summaries={}) + monitor = PluginHealthMonitor(tracker) + result = monitor.get_all_plugin_health() + assert result == {} + + +# --------------------------------------------------------------------------- +# _get_recovery_suggestions +# --------------------------------------------------------------------------- + +class TestGetRecoverySuggestions: + def test_healthy_plugin_suggestion(self): + tracker = _make_health_tracker(_healthy_summary()) + monitor = PluginHealthMonitor(tracker) + suggestions = monitor._get_recovery_suggestions("p", _healthy_summary(), HealthStatus.HEALTHY) + assert any("healthy" in s.lower() for s in suggestions) + + def test_unhealthy_suggestions(self): + tracker = _make_health_tracker(_unhealthy_summary()) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", _unhealthy_summary(), HealthStatus.UNHEALTHY) + assert len(suggestions) > 0 + assert any("unhealthy" in s.lower() for s in suggestions) + + def test_open_circuit_breaker_suggestion(self): + summary = _unhealthy_summary() + summary["circuit_state"] = "open" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", summary, HealthStatus.UNHEALTHY) + assert any("circuit" in s.lower() for s in suggestions) + + def test_timeout_error_suggestion(self): + summary = _degraded_summary() + summary["last_error"] = "connection timeout occurred" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", summary, HealthStatus.DEGRADED) + assert any("timeout" in s.lower() for s in suggestions) + + def test_import_error_suggestion(self): + summary = _unhealthy_summary() + summary["last_error"] = "ImportError: missing module" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", summary, HealthStatus.UNHEALTHY) + assert any("dependencies" in s.lower() or "import" in s.lower() or "missing" in s.lower() + for s in suggestions) + + def test_permission_error_suggestion(self): + summary = _unhealthy_summary() + summary["last_error"] = "permission denied to access resource" + tracker = _make_health_tracker(summary) + monitor = PluginHealthMonitor(tracker, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", summary, HealthStatus.UNHEALTHY) + assert any("permission" in s.lower() for s in suggestions) + + def test_degraded_suggestions_include_error_rate(self): + tracker = _make_health_tracker(_degraded_summary()) + monitor = PluginHealthMonitor(tracker, degraded_threshold=0.5, unhealthy_threshold=0.8) + suggestions = monitor._get_recovery_suggestions("p", _degraded_summary(), HealthStatus.DEGRADED) + assert any("%" in s for s in suggestions) + + +# --------------------------------------------------------------------------- +# start / stop monitoring +# --------------------------------------------------------------------------- + +class TestMonitorLifecycle: + def test_start_monitoring(self, monitor): + monitor.start_monitoring() + assert monitor._monitor_thread is not None + assert monitor._monitor_thread.is_alive() + monitor.stop_monitoring() + + def test_stop_monitoring(self, monitor): + monitor.start_monitoring() + monitor.stop_monitoring() + # Thread should no longer be alive + assert not monitor._monitor_thread.is_alive() + + def test_double_start_no_duplicate_threads(self, monitor): + monitor.start_monitoring() + thread1 = monitor._monitor_thread + monitor.start_monitoring() # should be idempotent + assert monitor._monitor_thread is thread1 + monitor.stop_monitoring() + + def test_register_health_check(self, monitor): + callback = MagicMock() + monitor.register_health_check(callback) + assert callback in monitor._health_check_callbacks diff --git a/test/test_logo_downloader.py b/test/test_logo_downloader.py new file mode 100644 index 00000000..9b79d45b --- /dev/null +++ b/test/test_logo_downloader.py @@ -0,0 +1,129 @@ +""" +Tests for src/logo_downloader.py + +Focuses on the pure/static methods that don't require network calls: +normalize_abbreviation, get_logo_filename_variations, get_logo_directory, +ensure_logo_directory, and the download_missing_logo function path +(with HTTP mocked). +""" + +import os +import pytest +from pathlib import Path +from unittest.mock import patch, Mock, MagicMock + +from src.logo_downloader import LogoDownloader + + +# --------------------------------------------------------------------------- +# normalize_abbreviation +# --------------------------------------------------------------------------- + +class TestNormalizeAbbreviation: + def test_basic_lowercase(self): + result = LogoDownloader.normalize_abbreviation("lal") + assert result == "LAL" + + def test_uppercases(self): + result = LogoDownloader.normalize_abbreviation("bos") + assert result == "BOS" + + def test_ampersand_replaced(self): + result = LogoDownloader.normalize_abbreviation("TA&M") + assert "&" not in result + assert "AND" in result + + def test_forward_slash_replaced(self): + result = LogoDownloader.normalize_abbreviation("A/B") + assert "/" not in result + + def test_empty_returns_empty(self): + result = LogoDownloader.normalize_abbreviation("") + assert result == "" + + +# --------------------------------------------------------------------------- +# get_logo_filename_variations +# --------------------------------------------------------------------------- + +class TestGetLogoFilenameVariations: + def test_returns_list(self): + result = LogoDownloader.get_logo_filename_variations("LAL") + assert isinstance(result, list) + assert len(result) > 0 + + def test_includes_png(self): + result = LogoDownloader.get_logo_filename_variations("KC") + filenames = " ".join(result) + assert ".png" in filenames + + def test_includes_original(self): + result = LogoDownloader.get_logo_filename_variations("LAL") + assert any("LAL" in f for f in result) + + def test_ampersand_variation(self): + result = LogoDownloader.get_logo_filename_variations("TA&M") + # Should produce at least the normalized version + assert len(result) > 0 + + def test_empty_string_no_crash(self): + result = LogoDownloader.get_logo_filename_variations("") + assert isinstance(result, list) + + +# --------------------------------------------------------------------------- +# get_logo_directory +# --------------------------------------------------------------------------- + +class TestGetLogoDirectory: + def test_known_sport_returns_string(self): + downloader = LogoDownloader() + result = downloader.get_logo_directory("nfl") + assert isinstance(result, str) + assert len(result) > 0 + + def test_known_sport_nba(self): + downloader = LogoDownloader() + result = downloader.get_logo_directory("nba") + assert "nba" in result.lower() or "sports" in result.lower() + + def test_unknown_sport_returns_string(self): + downloader = LogoDownloader() + result = downloader.get_logo_directory("unknown_sport_xyz") + assert isinstance(result, str) + + +# --------------------------------------------------------------------------- +# ensure_logo_directory +# --------------------------------------------------------------------------- + +class TestEnsureLogoDirectory: + def test_creates_writable_directory(self, tmp_path): + downloader = LogoDownloader() + test_dir = str(tmp_path / "logos" / "nfl") + result = downloader.ensure_logo_directory(test_dir) + assert result is True + assert Path(test_dir).is_dir() + + def test_existing_writable_directory(self, tmp_path): + downloader = LogoDownloader() + test_dir = str(tmp_path) + result = downloader.ensure_logo_directory(test_dir) + assert result is True + + def test_returns_false_when_write_test_fails(self, tmp_path): + """Simulate a directory that exists but raises PermissionError on write.""" + downloader = LogoDownloader() + test_dir = str(tmp_path / "logos") + + import builtins + original_open = builtins.open + + def mock_open(path, *args, **kwargs): + if ".write_test" in str(path): + raise PermissionError("no write access") + return original_open(path, *args, **kwargs) + + with patch("builtins.open", side_effect=mock_open): + result = downloader.ensure_logo_directory(test_dir) + assert result is False diff --git a/test/test_scroll_helper.py b/test/test_scroll_helper.py new file mode 100644 index 00000000..30e76b9c --- /dev/null +++ b/test/test_scroll_helper.py @@ -0,0 +1,317 @@ +""" +Tests for src/common/scroll_helper.py + +Covers ScrollHelper: create_scrolling_image, update_scroll_position, +get_visible_portion, calculate_dynamic_duration, set_* methods, +reset_scroll, clear_cache, get_scroll_info. +""" + +import pytest +import time +from unittest.mock import patch +from PIL import Image + +from src.common.scroll_helper import ScrollHelper + + +DISPLAY_W = 64 +DISPLAY_H = 32 + + +@pytest.fixture +def helper(): + return ScrollHelper(display_width=DISPLAY_W, display_height=DISPLAY_H) + + +def _make_image(width: int = 64, height: int = 32, color=(255, 0, 0)) -> Image.Image: + img = Image.new("RGB", (width, height), color) + return img + + +# --------------------------------------------------------------------------- +# __init__ / initial state +# --------------------------------------------------------------------------- + +class TestScrollHelperInit: + def test_initial_scroll_position(self, helper): + assert helper.scroll_position == 0.0 + + def test_initial_scroll_complete_false(self, helper): + assert helper.scroll_complete is False + + def test_display_dimensions(self, helper): + assert helper.display_width == DISPLAY_W + assert helper.display_height == DISPLAY_H + + +# --------------------------------------------------------------------------- +# create_scrolling_image +# --------------------------------------------------------------------------- + +class TestCreateScrollingImage: + def test_empty_content_returns_blank_image(self, helper): + result = helper.create_scrolling_image([]) + assert isinstance(result, Image.Image) + assert helper.total_scroll_width == 0 + + def test_single_item_creates_image(self, helper): + img = _make_image(width=100) + result = helper.create_scrolling_image([img]) + assert isinstance(result, Image.Image) + assert result.width > DISPLAY_W # includes leading gap + + def test_multiple_items_wider_image(self, helper): + items = [_make_image(width=50), _make_image(width=50)] + result = helper.create_scrolling_image(items) + # Should be wider than two items alone + assert result.width > 100 + + def test_scroll_position_reset(self, helper): + helper.scroll_position = 500.0 + helper.create_scrolling_image([_make_image()]) + assert helper.scroll_position == 0.0 + + def test_cached_array_set(self, helper): + helper.create_scrolling_image([_make_image()]) + assert helper.cached_array is not None + + def test_scroll_complete_reset(self, helper): + helper.scroll_complete = True + helper.create_scrolling_image([_make_image()]) + assert helper.scroll_complete is False + + def test_total_scroll_width_matches_image(self, helper): + img = _make_image(width=200) + result = helper.create_scrolling_image([img]) + assert helper.total_scroll_width == result.width + + +# --------------------------------------------------------------------------- +# set_scrolling_image +# --------------------------------------------------------------------------- + +class TestSetScrollingImage: + def test_sets_cached_image(self, helper): + img = _make_image(width=200) + helper.set_scrolling_image(img) + assert helper.cached_image is img + + def test_sets_cached_array(self, helper): + img = _make_image(width=200) + helper.set_scrolling_image(img) + assert helper.cached_array is not None + + def test_scroll_width_matches_image(self, helper): + img = _make_image(width=300) + helper.set_scrolling_image(img) + assert helper.total_scroll_width == 300 + + def test_none_clears_cache(self, helper): + helper.set_scrolling_image(_make_image()) + helper.set_scrolling_image(None) + assert helper.cached_image is None + + +# --------------------------------------------------------------------------- +# update_scroll_position (time-based mode) +# --------------------------------------------------------------------------- + +class TestUpdateScrollPosition: + def test_position_advances_over_time(self, helper): + helper.create_scrolling_image([_make_image(width=200)]) + helper.scroll_speed = 100.0 # 100 px/s + helper.last_update_time = time.time() - 0.1 # pretend 100ms elapsed + initial = helper.scroll_position + helper.update_scroll_position() + assert helper.scroll_position > initial + + def test_no_advance_without_image(self, helper): + helper.update_scroll_position() # no image, should not crash + assert helper.scroll_position == 0.0 + + def test_zero_width_content_stays_zero(self, helper): + helper.create_scrolling_image([]) # empty → width 0 + helper.update_scroll_position() + assert helper.scroll_position == 0.0 + + def test_scroll_complete_clamped(self, helper): + helper.create_scrolling_image([_make_image(width=100)]) + # Force position past the end + helper.scroll_position = helper.total_scroll_width + 50 + helper.total_distance_scrolled = helper.total_scroll_width + 50 + helper.update_scroll_position() + assert helper.scroll_complete is True + assert helper.scroll_position <= helper.total_scroll_width + + +# --------------------------------------------------------------------------- +# get_visible_portion +# --------------------------------------------------------------------------- + +class TestGetVisiblePortion: + def test_returns_none_without_image(self, helper): + assert helper.get_visible_portion() is None + + def test_returns_image_sized_to_display(self, helper): + helper.create_scrolling_image([_make_image(width=200)]) + visible = helper.get_visible_portion() + assert visible is not None + assert visible.width == DISPLAY_W + assert visible.height == DISPLAY_H + + def test_different_positions_give_different_images(self, helper): + helper.create_scrolling_image([_make_image(width=300)]) + img1 = helper.get_visible_portion() + helper.scroll_position = 50 + img2 = helper.get_visible_portion() + # Images should differ (colour from scrolled content) + # Just verify both are valid PIL images with correct size + assert img1.width == img2.width == DISPLAY_W + + +# --------------------------------------------------------------------------- +# reset_scroll / clear_cache +# --------------------------------------------------------------------------- + +class TestResetAndClear: + def test_reset_restores_position(self, helper): + helper.create_scrolling_image([_make_image(width=200)]) + helper.scroll_position = 100.0 + helper.reset_scroll() + assert helper.scroll_position == 0.0 + + def test_reset_clears_complete_flag(self, helper): + helper.scroll_complete = True + helper.reset_scroll() + assert helper.scroll_complete is False + + def test_reset_alias(self, helper): + helper.scroll_position = 50.0 + helper.reset() + assert helper.scroll_position == 0.0 + + def test_clear_cache(self, helper): + helper.create_scrolling_image([_make_image()]) + helper.clear_cache() + assert helper.cached_image is None + assert helper.cached_array is None + assert helper.total_scroll_width == 0 + + +# --------------------------------------------------------------------------- +# calculate_dynamic_duration +# --------------------------------------------------------------------------- + +class TestCalculateDynamicDuration: + def test_returns_min_when_disabled(self, helper): + helper.dynamic_duration_enabled = False + helper.min_duration = 30 + result = helper.calculate_dynamic_duration() + assert result == 30 + + def test_returns_min_when_no_content(self, helper): + helper.total_scroll_width = 0 + helper.min_duration = 30 + result = helper.calculate_dynamic_duration() + assert result == 30 + + def test_respects_min_duration(self, helper): + helper.create_scrolling_image([_make_image(width=50)]) + helper.min_duration = 60 + helper.max_duration = 300 + helper.scroll_speed = 500.0 # very fast → very short time + result = helper.calculate_dynamic_duration() + assert result >= 60 + + def test_respects_max_duration(self, helper): + helper.create_scrolling_image([_make_image(width=50000)]) + helper.min_duration = 10 + helper.max_duration = 60 + helper.scroll_speed = 1.0 # very slow → very long time + result = helper.calculate_dynamic_duration() + assert result <= 60 + + def test_time_based_calculation(self, helper): + helper.create_scrolling_image([_make_image(width=200)]) + helper.scroll_speed = 100.0 + helper.min_duration = 1 + helper.max_duration = 600 + helper.frame_based_scrolling = False + result = helper.calculate_dynamic_duration() + assert isinstance(result, int) + assert result > 0 + + +# --------------------------------------------------------------------------- +# set_* configuration methods +# --------------------------------------------------------------------------- + +class TestSetMethods: + def test_set_scroll_speed_time_based(self, helper): + helper.frame_based_scrolling = False + helper.set_scroll_speed(50.0) + assert helper.scroll_speed == 50.0 + + def test_set_scroll_speed_clamped_low(self, helper): + helper.frame_based_scrolling = False + helper.set_scroll_speed(0.0) + assert helper.scroll_speed >= 1.0 + + def test_set_scroll_speed_clamped_high(self, helper): + helper.frame_based_scrolling = False + helper.set_scroll_speed(10000.0) + assert helper.scroll_speed <= 500.0 + + def test_set_scroll_delay(self, helper): + helper.set_scroll_delay(0.05) + assert helper.scroll_delay == 0.05 + + def test_set_scroll_delay_clamped(self, helper): + helper.set_scroll_delay(0.0001) + assert helper.scroll_delay >= 0.001 + + def test_set_target_fps(self, helper): + helper.set_target_fps(60.0) + assert helper.target_fps == 60.0 + + def test_set_target_fps_clamped(self, helper): + helper.set_target_fps(1000.0) + assert helper.target_fps <= 200.0 + + def test_set_sub_pixel_scrolling(self, helper): + helper.set_sub_pixel_scrolling(True) + assert helper.sub_pixel_scrolling is True + helper.set_sub_pixel_scrolling(False) + assert helper.sub_pixel_scrolling is False + + def test_set_frame_based_scrolling(self, helper): + helper.set_frame_based_scrolling(True) + assert helper.frame_based_scrolling is True + + def test_set_dynamic_duration_settings(self, helper): + helper.set_dynamic_duration_settings(enabled=True, min_duration=20, max_duration=120, buffer=0.2) + assert helper.dynamic_duration_enabled is True + assert helper.min_duration == 20 + assert helper.max_duration == 120 + assert helper.duration_buffer == pytest.approx(0.2) + + +# --------------------------------------------------------------------------- +# get_scroll_info +# --------------------------------------------------------------------------- + +class TestGetScrollInfo: + def test_returns_dict(self, helper): + info = helper.get_scroll_info() + assert isinstance(info, dict) + + def test_required_keys(self, helper): + info = helper.get_scroll_info() + for key in ("scroll_position", "total_distance_scrolled", "scroll_speed", + "scroll_complete", "dynamic_duration"): + assert key in info + + def test_scroll_position_reflected(self, helper): + helper.scroll_position = 42.0 + info = helper.get_scroll_info() + assert info["scroll_position"] == 42.0 diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 00000000..127b743e --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,329 @@ +""" +Tests for src/common/utils.py + +Covers all pure utility functions: normalize_team_abbreviation, format_time, +format_date, get_timezone, validate_dimensions, parse_team_abbreviation, +format_score, format_period, is_live_game, is_final_game, is_upcoming_game, +sanitize_filename, truncate_text, parse_boolean. +""" + +import pytest +from datetime import datetime, timezone +import pytz + +from src.common.utils import ( + normalize_team_abbreviation, + format_time, + format_date, + get_timezone, + validate_dimensions, + parse_team_abbreviation, + format_score, + format_period, + is_live_game, + is_final_game, + is_upcoming_game, + sanitize_filename, + truncate_text, + parse_boolean, +) + + +# --------------------------------------------------------------------------- +# normalize_team_abbreviation +# --------------------------------------------------------------------------- + +class TestNormalizeTeamAbbreviation: + def test_basic_uppercase(self): + assert normalize_team_abbreviation("lal") == "LAL" + + def test_strips_spaces(self): + assert normalize_team_abbreviation(" KC ") == "KC" + + def test_replaces_ampersand(self): + assert normalize_team_abbreviation("TA&M") == "TAANDM" + + def test_removes_internal_spaces(self): + assert normalize_team_abbreviation("A B") == "AB" + + def test_removes_hyphens(self): + assert normalize_team_abbreviation("A-B") == "AB" + + def test_empty_string_returns_empty(self): + assert normalize_team_abbreviation("") == "" + + def test_none_returns_empty(self): + assert normalize_team_abbreviation(None) == "" + + +# --------------------------------------------------------------------------- +# format_time / format_date +# --------------------------------------------------------------------------- + +class TestFormatTime: + def _utc_dt(self, hour=20, minute=30): + return datetime(2024, 1, 15, hour, minute, 0, tzinfo=timezone.utc) + + def test_formats_utc_to_utc(self): + dt = self._utc_dt(20, 30) + result = format_time(dt, timezone_str="UTC") + # 20:30 UTC → "8:30PM" (leading zero stripped) + assert "8:30PM" in result or "8:30 PM" in result or result != "" + + def test_naive_datetime_treated_as_utc(self): + dt = datetime(2024, 1, 15, 12, 0, 0) # naive + result = format_time(dt, timezone_str="UTC") + assert result != "" + + def test_invalid_timezone_returns_empty(self): + dt = self._utc_dt() + result = format_time(dt, timezone_str="Invalid/TZ") + assert result == "" + + def test_eastern_timezone(self): + dt = self._utc_dt(20, 0) # 8 PM UTC = 3 PM ET + result = format_time(dt, timezone_str="America/New_York") + assert result != "" + + +class TestFormatDate: + def test_formats_date(self): + dt = datetime(2024, 6, 15, 18, 0, 0, tzinfo=timezone.utc) + result = format_date(dt, timezone_str="UTC") + assert "June" in result or "15" in result + + def test_naive_datetime(self): + dt = datetime(2024, 3, 10, 12, 0, 0) + result = format_date(dt, timezone_str="UTC") + assert result != "" + + def test_invalid_timezone_returns_empty(self): + dt = datetime(2024, 6, 15, 18, 0, 0, tzinfo=timezone.utc) + result = format_date(dt, timezone_str="BadZone/Here") + assert result == "" + + +# --------------------------------------------------------------------------- +# get_timezone +# --------------------------------------------------------------------------- + +class TestGetTimezone: + def test_valid_timezone(self): + tz = get_timezone("America/New_York") + assert tz is not None + + def test_utc(self): + tz = get_timezone("UTC") + assert tz is pytz.utc or str(tz) == "UTC" + + def test_invalid_returns_utc(self): + tz = get_timezone("Not/ATimezone") + assert tz is pytz.utc + + +# --------------------------------------------------------------------------- +# validate_dimensions +# --------------------------------------------------------------------------- + +class TestValidateDimensions: + def test_valid(self): + assert validate_dimensions(64, 32) is True + + def test_zero_width(self): + assert validate_dimensions(0, 32) is False + + def test_zero_height(self): + assert validate_dimensions(64, 0) is False + + def test_negative(self): + assert validate_dimensions(-1, 32) is False + + def test_too_large(self): + assert validate_dimensions(1001, 32) is False + + def test_max_valid(self): + assert validate_dimensions(1000, 1000) is True + + def test_non_integer(self): + assert validate_dimensions("64", 32) is False # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# parse_team_abbreviation +# --------------------------------------------------------------------------- + +class TestParseTeamAbbreviation: + def test_empty_string(self): + assert parse_team_abbreviation("") == "" + + def test_none_returns_empty(self): + assert parse_team_abbreviation(None) == "" + + def test_extracts_uppercase(self): + result = parse_team_abbreviation("LAL") + assert result == "LAL" + + def test_fallback_first_three(self): + # text without recognisable 2-4 char uppercase block + result = parse_team_abbreviation("ab") + assert len(result) <= 3 + + +# --------------------------------------------------------------------------- +# format_score +# --------------------------------------------------------------------------- + +class TestFormatScore: + def test_format_score(self): + assert format_score(14, 7) == "7-14" + + def test_format_score_strings(self): + assert format_score("21", "14") == "14-21" + + def test_zero_zero(self): + assert format_score(0, 0) == "0-0" + + +# --------------------------------------------------------------------------- +# format_period +# --------------------------------------------------------------------------- + +class TestFormatPeriod: + def test_basketball_q1(self): + assert format_period(1, "basketball") == "Q1" + + def test_basketball_q4(self): + assert format_period(4, "basketball") == "Q4" + + def test_basketball_ot1(self): + assert format_period(5, "basketball") == "OT1" + + def test_basketball_ot2(self): + assert format_period(6, "basketball") == "OT2" + + def test_football_q1(self): + assert format_period(1, "football") == "Q1" + + def test_football_ot(self): + assert format_period(5, "football") == "OT1" + + def test_hockey_p1(self): + assert format_period(1, "hockey") == "P1" + + def test_hockey_p3(self): + assert format_period(3, "hockey") == "P3" + + def test_hockey_ot(self): + assert format_period(4, "hockey") == "OT1" + + def test_baseball_inning(self): + assert format_period(7, "baseball") == "INN 7" + + def test_unknown_sport(self): + result = format_period(2, "unknown") + assert "2" in result + + +# --------------------------------------------------------------------------- +# is_live_game / is_final_game / is_upcoming_game +# --------------------------------------------------------------------------- + +class TestGameStatusHelpers: + def test_is_live_game_true(self): + assert is_live_game("In Progress") is True + assert is_live_game("halftime") is True + assert is_live_game("overtime") is True + + def test_is_live_game_false(self): + assert is_live_game("Final") is False + assert is_live_game("Scheduled") is False + + def test_is_final_game_true(self): + assert is_final_game("Final") is True + assert is_final_game("COMPLETED") is True + + def test_is_final_game_false(self): + assert is_final_game("In Progress") is False + + def test_is_upcoming_game_true(self): + assert is_upcoming_game("Scheduled") is True + assert is_upcoming_game("upcoming") is True + + def test_is_upcoming_game_false(self): + assert is_upcoming_game("Final") is False + assert is_upcoming_game("In Progress") is False + + +# --------------------------------------------------------------------------- +# sanitize_filename +# --------------------------------------------------------------------------- + +class TestSanitizeFilename: + def test_removes_invalid_chars(self): + result = sanitize_filename('file<>:"/\\|?*.txt') + assert "<" not in result + assert ">" not in result + assert ":" not in result + + def test_collapses_underscores(self): + result = sanitize_filename("file___name") + assert "__" not in result + + def test_strips_leading_trailing(self): + result = sanitize_filename("_file_") + assert not result.startswith("_") + assert not result.endswith("_") + + def test_normal_filename_unchanged(self): + result = sanitize_filename("my_logo") + assert result == "my_logo" + + +# --------------------------------------------------------------------------- +# truncate_text +# --------------------------------------------------------------------------- + +class TestTruncateText: + def test_no_truncation_needed(self): + assert truncate_text("hello", 10) == "hello" + + def test_truncation_adds_suffix(self): + result = truncate_text("hello world", 8) + assert result.endswith("...") + assert len(result) == 8 + + def test_exact_length(self): + assert truncate_text("hello", 5) == "hello" + + def test_custom_suffix(self): + result = truncate_text("hello world", 8, suffix="~") + assert result.endswith("~") + + +# --------------------------------------------------------------------------- +# parse_boolean +# --------------------------------------------------------------------------- + +class TestParseBoolean: + def test_true_bool(self): + assert parse_boolean(True) is True + + def test_false_bool(self): + assert parse_boolean(False) is False + + def test_int_1(self): + assert parse_boolean(1) is True + + def test_int_0(self): + assert parse_boolean(0) is False + + def test_string_true(self): + for val in ("true", "True", "TRUE", "1", "yes", "on", "enabled"): + assert parse_boolean(val) is True, f"Expected True for {val!r}" + + def test_string_false(self): + for val in ("false", "False", "0", "no", "off", "disabled"): + assert parse_boolean(val) is False, f"Expected False for {val!r}" + + def test_none_returns_false(self): + assert parse_boolean(None) is False # type: ignore[arg-type] diff --git a/test/test_vegas_config.py b/test/test_vegas_config.py new file mode 100644 index 00000000..ebacafe7 --- /dev/null +++ b/test/test_vegas_config.py @@ -0,0 +1,310 @@ +""" +Tests for src/vegas_mode/config.py + +Covers VegasModeConfig: from_config, to_dict, get_frame_interval, +is_plugin_included, get_ordered_plugins, validate, update. +""" + +import pytest +from src.vegas_mode.config import VegasModeConfig + + +# --------------------------------------------------------------------------- +# Default construction +# --------------------------------------------------------------------------- + +class TestVegasModeConfigDefaults: + def test_default_disabled(self): + cfg = VegasModeConfig() + assert cfg.enabled is False + + def test_default_scroll_speed(self): + cfg = VegasModeConfig() + assert cfg.scroll_speed == 50.0 + + def test_default_separator_width(self): + cfg = VegasModeConfig() + assert cfg.separator_width == 32 + + def test_default_target_fps(self): + cfg = VegasModeConfig() + assert cfg.target_fps == 125 + + def test_default_plugin_order_empty(self): + cfg = VegasModeConfig() + assert cfg.plugin_order == [] + + def test_default_excluded_plugins_empty(self): + cfg = VegasModeConfig() + assert len(cfg.excluded_plugins) == 0 + + +# --------------------------------------------------------------------------- +# from_config +# --------------------------------------------------------------------------- + +class TestFromConfig: + def _cfg(self, **kwargs) -> dict: + return {"display": {"vegas_scroll": kwargs}} + + def test_enabled_flag(self): + cfg = VegasModeConfig.from_config(self._cfg(enabled=True)) + assert cfg.enabled is True + + def test_scroll_speed(self): + cfg = VegasModeConfig.from_config(self._cfg(scroll_speed=80.0)) + assert cfg.scroll_speed == 80.0 + + def test_separator_width(self): + cfg = VegasModeConfig.from_config(self._cfg(separator_width=16)) + assert cfg.separator_width == 16 + + def test_plugin_order(self): + cfg = VegasModeConfig.from_config(self._cfg(plugin_order=["a", "b", "c"])) + assert cfg.plugin_order == ["a", "b", "c"] + + def test_excluded_plugins(self): + cfg = VegasModeConfig.from_config(self._cfg(excluded_plugins=["x", "y"])) + assert "x" in cfg.excluded_plugins + assert "y" in cfg.excluded_plugins + + def test_target_fps(self): + cfg = VegasModeConfig.from_config(self._cfg(target_fps=60)) + assert cfg.target_fps == 60 + + def test_buffer_ahead(self): + cfg = VegasModeConfig.from_config(self._cfg(buffer_ahead=3)) + assert cfg.buffer_ahead == 3 + + def test_min_max_cycle_duration(self): + cfg = VegasModeConfig.from_config(self._cfg(min_cycle_duration=30, max_cycle_duration=120)) + assert cfg.min_cycle_duration == 30 + assert cfg.max_cycle_duration == 120 + + def test_defaults_when_missing(self): + cfg = VegasModeConfig.from_config({}) + assert cfg.enabled is False + assert cfg.scroll_speed == 50.0 + + def test_frame_based_scrolling(self): + cfg = VegasModeConfig.from_config(self._cfg(frame_based_scrolling=False)) + assert cfg.frame_based_scrolling is False + + +# --------------------------------------------------------------------------- +# to_dict +# --------------------------------------------------------------------------- + +class TestToDict: + def test_roundtrip(self): + original = VegasModeConfig( + enabled=True, + scroll_speed=75.0, + separator_width=24, + plugin_order=["a", "b"], + excluded_plugins={"z"}, + target_fps=100, + ) + d = original.to_dict() + assert d["enabled"] is True + assert d["scroll_speed"] == 75.0 + assert d["separator_width"] == 24 + assert d["plugin_order"] == ["a", "b"] + assert "z" in d["excluded_plugins"] + assert d["target_fps"] == 100 + + def test_excluded_plugins_is_list(self): + cfg = VegasModeConfig(excluded_plugins={"x"}) + d = cfg.to_dict() + assert isinstance(d["excluded_plugins"], list) + + def test_all_keys_present(self): + d = VegasModeConfig().to_dict() + for key in ("enabled", "scroll_speed", "separator_width", "plugin_order", + "excluded_plugins", "target_fps", "buffer_ahead", + "frame_based_scrolling", "scroll_delay", + "dynamic_duration_enabled", "min_cycle_duration", "max_cycle_duration"): + assert key in d + + +# --------------------------------------------------------------------------- +# get_frame_interval +# --------------------------------------------------------------------------- + +class TestGetFrameInterval: + def test_125fps(self): + cfg = VegasModeConfig(target_fps=125) + assert abs(cfg.get_frame_interval() - 1.0 / 125) < 1e-9 + + def test_60fps(self): + cfg = VegasModeConfig(target_fps=60) + assert abs(cfg.get_frame_interval() - 1.0 / 60) < 1e-6 + + def test_zero_fps_guarded(self): + cfg = VegasModeConfig(target_fps=0) + # Should not raise ZeroDivisionError (max(1, fps) guard) + result = cfg.get_frame_interval() + assert result == 1.0 + + +# --------------------------------------------------------------------------- +# is_plugin_included +# --------------------------------------------------------------------------- + +class TestIsPluginIncluded: + def test_not_excluded_is_included(self): + cfg = VegasModeConfig(excluded_plugins={"bad_plugin"}) + assert cfg.is_plugin_included("good_plugin") is True + + def test_excluded_plugin_not_included(self): + cfg = VegasModeConfig(excluded_plugins={"bad_plugin"}) + assert cfg.is_plugin_included("bad_plugin") is False + + def test_empty_exclusions_all_included(self): + cfg = VegasModeConfig() + assert cfg.is_plugin_included("anything") is True + + +# --------------------------------------------------------------------------- +# get_ordered_plugins +# --------------------------------------------------------------------------- + +class TestGetOrderedPlugins: + def test_natural_order_when_no_order_configured(self): + cfg = VegasModeConfig() + available = ["a", "b", "c"] + result = cfg.get_ordered_plugins(available) + assert result == ["a", "b", "c"] + + def test_explicit_order_followed(self): + cfg = VegasModeConfig(plugin_order=["c", "a", "b"]) + available = ["a", "b", "c"] + result = cfg.get_ordered_plugins(available) + assert result == ["c", "a", "b"] + + def test_unavailable_plugins_skipped(self): + cfg = VegasModeConfig(plugin_order=["c", "x", "a"]) + available = ["a", "b", "c"] + result = cfg.get_ordered_plugins(available) + assert "x" not in result + assert result[:2] == ["c", "a"] + + def test_excluded_plugins_removed(self): + cfg = VegasModeConfig(excluded_plugins={"b"}) + available = ["a", "b", "c"] + result = cfg.get_ordered_plugins(available) + assert "b" not in result + + def test_unordered_available_appended(self): + cfg = VegasModeConfig(plugin_order=["a"]) + available = ["a", "b", "c"] + result = cfg.get_ordered_plugins(available) + assert result[0] == "a" + assert "b" in result + assert "c" in result + + def test_empty_available(self): + cfg = VegasModeConfig(plugin_order=["a"]) + result = cfg.get_ordered_plugins([]) + assert result == [] + + +# --------------------------------------------------------------------------- +# validate +# --------------------------------------------------------------------------- + +class TestValidate: + def test_valid_config_no_errors(self): + cfg = VegasModeConfig() + errors = cfg.validate() + assert errors == [] + + def test_scroll_speed_too_low(self): + cfg = VegasModeConfig(scroll_speed=0.5) + errors = cfg.validate() + assert any("scroll_speed" in e for e in errors) + + def test_scroll_speed_too_high(self): + cfg = VegasModeConfig(scroll_speed=300.0) + errors = cfg.validate() + assert any("scroll_speed" in e for e in errors) + + def test_separator_width_negative(self): + cfg = VegasModeConfig(separator_width=-1) + errors = cfg.validate() + assert any("separator_width" in e for e in errors) + + def test_separator_width_too_large(self): + cfg = VegasModeConfig(separator_width=200) + errors = cfg.validate() + assert any("separator_width" in e for e in errors) + + def test_target_fps_too_low(self): + cfg = VegasModeConfig(target_fps=10) + errors = cfg.validate() + assert any("target_fps" in e for e in errors) + + def test_target_fps_too_high(self): + cfg = VegasModeConfig(target_fps=300) + errors = cfg.validate() + assert any("target_fps" in e for e in errors) + + def test_buffer_ahead_too_low(self): + cfg = VegasModeConfig(buffer_ahead=0) + errors = cfg.validate() + assert any("buffer_ahead" in e for e in errors) + + def test_buffer_ahead_too_high(self): + cfg = VegasModeConfig(buffer_ahead=10) + errors = cfg.validate() + assert any("buffer_ahead" in e for e in errors) + + def test_multiple_errors_returned(self): + cfg = VegasModeConfig(scroll_speed=0.1, target_fps=5) + errors = cfg.validate() + assert len(errors) >= 2 + + +# --------------------------------------------------------------------------- +# update +# --------------------------------------------------------------------------- + +class TestUpdate: + def _wrap(self, **kwargs) -> dict: + return {"display": {"vegas_scroll": kwargs}} + + def test_update_enabled(self): + cfg = VegasModeConfig(enabled=False) + cfg.update(self._wrap(enabled=True)) + assert cfg.enabled is True + + def test_update_scroll_speed(self): + cfg = VegasModeConfig(scroll_speed=50.0) + cfg.update(self._wrap(scroll_speed=90.0)) + assert cfg.scroll_speed == 90.0 + + def test_update_separator_width(self): + cfg = VegasModeConfig(separator_width=32) + cfg.update(self._wrap(separator_width=8)) + assert cfg.separator_width == 8 + + def test_update_plugin_order(self): + cfg = VegasModeConfig(plugin_order=[]) + cfg.update(self._wrap(plugin_order=["x", "y"])) + assert cfg.plugin_order == ["x", "y"] + + def test_update_excluded_plugins(self): + cfg = VegasModeConfig() + cfg.update(self._wrap(excluded_plugins=["skip_me"])) + assert "skip_me" in cfg.excluded_plugins + + def test_update_ignores_missing_keys(self): + cfg = VegasModeConfig(scroll_speed=50.0) + cfg.update(self._wrap(target_fps=80)) # only fps, not speed + assert cfg.scroll_speed == 50.0 + assert cfg.target_fps == 80 + + def test_empty_update_no_change(self): + cfg = VegasModeConfig(scroll_speed=50.0) + cfg.update({}) + assert cfg.scroll_speed == 50.0