From 5cba55d670843e727870e95023327a29cec0f87f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 20 Feb 2024 18:49:55 +1100 Subject: [PATCH] test: clean up & fix tests - Deduplicate the mock invocation services. This is possible now that the import order issue is resolved. - Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`. --- tests/app/routers/test_images.py | 49 ------------- .../bulk_download/test_bulk_download.py | 71 ++++--------------- .../services/download/test_download_queue.py | 6 +- tests/conftest.py | 55 +++++++++++++- tests/fixtures/event_service.py | 27 ------- tests/test_graph_execution_state.py | 39 ---------- tests/test_nodes.py | 15 ++-- 7 files changed, 78 insertions(+), 184 deletions(-) delete mode 100644 tests/fixtures/event_service.py diff --git a/tests/app/routers/test_images.py b/tests/app/routers/test_images.py index 67297a116f..5cb8cf1c37 100644 --- a/tests/app/routers/test_images.py +++ b/tests/app/routers/test_images.py @@ -1,66 +1,17 @@ from pathlib import Path from typing import Any -import pytest from fastapi import BackgroundTasks from fastapi.testclient import TestClient from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api_app import app from invokeai.app.services.board_records.board_records_common import BoardRecord -from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage -from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.images.images_default import ImageService -from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invoker import Invoker -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.sqlite_database import create_mock_sqlite_database client = TestClient(app) -@pytest.fixture -def mock_services(tmp_path: Path) -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) - - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=SqliteBoardRecordStorage(db=db), - boards=None, # type: ignore - bulk_download=BulkDownloadService(), - configuration=None, # type: ignore - events=None, # type: ignore - graph_execution_manager=None, # type: ignore - image_files=None, # type: ignore - image_records=None, # type: ignore - images=ImageService(), - invocation_cache=None, # type: ignore - latents=None, # type: ignore - logger=logger, - model_manager=None, # type: ignore - model_records=None, # type: ignore - download_queue=None, # type: ignore - model_install=None, # type: ignore - names=None, # type: ignore - performance_statistics=None, # type: ignore - processor=None, # type: ignore - queue=None, # type: ignore - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - class MockApiDependencies(ApiDependencies): invoker: Invoker diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index 3e8b7fd2eb..b18f6e038d 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -7,23 +7,17 @@ from zipfile import ZipFile import pytest from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException -from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService -from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecordNotFoundException, ResourceOrigin, ) from invokeai.app.services.images.images_common import ImageDTO -from invokeai.app.services.images.images_default import ImageService -from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invoker import Invoker from invokeai.app.services.shared.pagination import OffsetPaginatedResults -from invokeai.backend.util.logging import InvokeAILogger -from tests.fixtures.event_service import DummyEventService -from tests.fixtures.sqlite_database import create_mock_sqlite_database +from tests.test_nodes import TestEventService @pytest.fixture @@ -46,53 +40,6 @@ def mock_image_dto() -> ImageDTO: ) -@pytest.fixture -def mock_event_service() -> DummyEventService: - """Create a dummy event service.""" - return DummyEventService() - - -@pytest.fixture -def mock_services(mock_event_service: DummyEventService) -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - logger = InvokeAILogger.get_logger() - db = create_mock_sqlite_database(configuration, logger) - - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=SqliteBoardRecordStorage(db=db), - boards=None, # type: ignore - bulk_download=None, # type: ignore - configuration=None, # type: ignore - events=mock_event_service, - graph_execution_manager=None, # type: ignore - image_files=None, # type: ignore - image_records=None, # type: ignore - images=ImageService(), - invocation_cache=None, # type: ignore - latents=None, # type: ignore - logger=logger, - model_manager=None, # type: ignore - model_records=None, # type: ignore - download_queue=None, # type: ignore - model_install=None, # type: ignore - names=None, # type: ignore - performance_statistics=None, # type: ignore - processor=None, # type: ignore - queue=None, # type: ignore - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - ) - - -@pytest.fixture() -def mock_invoker(mock_services: InvocationServices) -> Invoker: - return Invoker(services=mock_services) - - @pytest.fixture(autouse=True) def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): """Mock the TemporaryDirectory class so that it uses the tmp_path fixture.""" @@ -288,6 +235,16 @@ def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: Image monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) + # This is used when preparing all images for a given board + def mock_get_all_board_image_names_for_board(*args, **kwargs): + return [mock_image_dto.image_name] + + monkeypatch.setattr( + mock_invoker.services.board_image_records, + "get_all_board_image_names_for_board", + mock_get_all_board_image_names_for_board, + ) + # Create a mock image file so that the contents of the zip file are not empty mock_image_path: Path = tmp_path / mock_image_dto.image_name mock_image_contents: str = "Totally an image" @@ -306,7 +263,7 @@ def assert_handler_success( expected_image_path: Path, mock_image_contents: str, tmp_path: Path, - event_bus: DummyEventService, + event_bus: TestEventService, ): """Assert that the handler was successful.""" # Check that the zip file was created @@ -369,7 +326,7 @@ def test_handler_on_generic_exception( with pytest.raises(Exception): # noqa: B017 execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) - event_bus: DummyEventService = mock_invoker.services.events + event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 assert event_bus.events[0].event_name == "bulk_download_started" @@ -384,7 +341,7 @@ def execute_handler_test_on_error( bulk_download_service.start(mock_invoker) bulk_download_service.handler([mock_image_dto.image_name], None, None) - event_bus: DummyEventService = mock_invoker.services.events + event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 assert event_bus.events[0].event_name == "bulk_download_started" diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 34408ac5ae..ff9b193b17 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -9,7 +9,7 @@ from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService -from tests.fixtures.event_service import DummyEventService +from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings TestAdapter.__test__ = False # type: ignore @@ -101,7 +101,7 @@ def test_errors(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, session: Session) -> None: - event_bus = DummyEventService() + event_bus = TestEventService() queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue.start() @@ -167,7 +167,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: def test_cancel(tmp_path: Path, session: Session) -> None: - event_bus = DummyEventService() + event_bus = TestEventService() queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue.start() diff --git a/tests/conftest.py b/tests/conftest.py index 873ccc13fd..a483b7529a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,4 +4,57 @@ # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. -from invokeai.backend.util.test_utils import torch_device # noqa: F401 +import logging + +import pytest + +from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage +from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage +from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.images.images_default import ImageService +from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache +from invokeai.app.services.invocation_services import InvocationServices +from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService +from invokeai.app.services.invoker import Invoker +from invokeai.backend.util.logging import InvokeAILogger +from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401 +from tests.test_nodes import TestEventService + + +@pytest.fixture +def mock_services() -> InvocationServices: + configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) + logger = InvokeAILogger.get_logger() + db = create_mock_sqlite_database(configuration, logger) + + # NOTE: none of these are actually called by the test invocations + return InvocationServices( + board_image_records=SqliteBoardImageRecordStorage(db=db), + board_images=None, # type: ignore + board_records=SqliteBoardRecordStorage(db=db), + boards=None, # type: ignore + bulk_download=BulkDownloadService(), + configuration=configuration, + events=TestEventService(), + image_files=None, # type: ignore + image_records=None, # type: ignore + images=ImageService(), + invocation_cache=MemoryInvocationCache(max_cache_size=0), + logger=logging, # type: ignore + model_manager=None, # type: ignore + download_queue=None, # type: ignore + names=None, # type: ignore + performance_statistics=InvocationStatsService(), + session_processor=None, # type: ignore + session_queue=None, # type: ignore + urls=None, # type: ignore + workflow_records=None, # type: ignore + tensors=None, # type: ignore + conditioning=None, # type: ignore + ) + + +@pytest.fixture() +def mock_invoker(mock_services: InvocationServices) -> Invoker: + return Invoker(services=mock_services) diff --git a/tests/fixtures/event_service.py b/tests/fixtures/event_service.py deleted file mode 100644 index 8f6a45c38f..0000000000 --- a/tests/fixtures/event_service.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, Dict, List - -from pydantic import BaseModel - -from invokeai.app.services.events.events_base import EventServiceBase - - -class DummyEvent(BaseModel): - """Dummy Event to use with Dummy Event service.""" - - event_name: str - payload: Dict[str, Any] - - -# A dummy event service for testing event issuing -class DummyEventService(EventServiceBase): - """Dummy event service for testing.""" - - events: List[DummyEvent] - - def __init__(self) -> None: - super().__init__() - self.events = [] - - def dispatch(self, event_name: str, payload: Any) -> None: - """Dispatch an event by appending it to self.events.""" - self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index 9a35037431..0bb15b17df 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -1,4 +1,3 @@ -import logging from typing import Optional from unittest.mock import Mock @@ -8,17 +7,12 @@ import pytest from .test_nodes import ( # isort: split PromptCollectionTestInvocation, PromptTestInvocation, - TestEventService, TextToImageTestInvocation, ) from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation -from invokeai.app.services.config.config_default import InvokeAIAppConfig -from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache -from invokeai.app.services.invocation_services import InvocationServices -from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.shared.graph import ( CollectInvocation, Graph, @@ -38,39 +32,6 @@ def simple_graph() -> Graph: return g -# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types -# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate -# the test invocations. -@pytest.fixture -def mock_services() -> InvocationServices: - configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0) - # NOTE: none of these are actually called by the test invocations - return InvocationServices( - board_image_records=None, # type: ignore - board_images=None, # type: ignore - board_records=None, # type: ignore - boards=None, # type: ignore - bulk_download=None, # type: ignore - configuration=configuration, - events=TestEventService(), - image_files=None, # type: ignore - image_records=None, # type: ignore - images=None, # type: ignore - invocation_cache=MemoryInvocationCache(max_cache_size=0), - logger=logging, # type: ignore - model_manager=None, # type: ignore - download_queue=None, # type: ignore - names=None, # type: ignore - performance_statistics=InvocationStatsService(), - session_processor=None, # type: ignore - session_queue=None, # type: ignore - urls=None, # type: ignore - workflow_records=None, # type: ignore - tensors=None, # type: ignore - conditioning=None, # type: ignore - ) - - def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]: n = g.next() if n is None: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index aab3d9c7b4..e1fe857040 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,5 +1,7 @@ from typing import Any, Callable, Union +from pydantic import BaseModel + from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -115,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg ) -class TestEvent: - event_name: str - payload: Any +class TestEvent(BaseModel): __test__ = False # not a pytest test case - def __init__(self, event_name: str, payload: Any): - self.event_name = event_name - self.payload = payload + event_name: str + payload: Any class TestEventService(EventServiceBase): - events: list __test__ = False # not a pytest test case def __init__(self): super().__init__() - self.events = [] + self.events: list[TestEvent] = [] def dispatch(self, event_name: str, payload: Any) -> None: + self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"])) pass