test: clean up & fix tests

- Deduplicate the mock invocation services. This is possible now that the import order issue is resolved.
- Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`.
This commit is contained in:
psychedelicious 2024-02-20 18:49:55 +11:00
parent cbb997e7d0
commit 5cba55d670
7 changed files with 78 additions and 184 deletions

View File

@ -1,66 +1,17 @@
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import pytest
from fastapi import BackgroundTasks from fastapi import BackgroundTasks
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
client = TestClient(app) client = TestClient(app)
@pytest.fixture
def mock_services(tmp_path: Path) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=None, # type: ignore
events=None, # type: ignore
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
class MockApiDependencies(ApiDependencies): class MockApiDependencies(ApiDependencies):
invoker: Invoker invoker: Invoker

View File

@ -7,23 +7,17 @@ from zipfile import ZipFile
import pytest import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ( from invokeai.app.services.image_records.image_records_common import (
ImageCategory, ImageCategory,
ImageRecordNotFoundException, ImageRecordNotFoundException,
ResourceOrigin, ResourceOrigin,
) )
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.backend.util.logging import InvokeAILogger from tests.test_nodes import TestEventService
from tests.fixtures.event_service import DummyEventService
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture @pytest.fixture
@ -46,53 +40,6 @@ def mock_image_dto() -> ImageDTO:
) )
@pytest.fixture
def mock_event_service() -> DummyEventService:
"""Create a dummy event service."""
return DummyEventService()
@pytest.fixture
def mock_services(mock_event_service: DummyEventService) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=None, # type: ignore
configuration=None, # type: ignore
events=mock_event_service,
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path): def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture.""" """Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
@ -288,6 +235,16 @@ def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: Image
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto) monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
# This is used when preparing all images for a given board
def mock_get_all_board_image_names_for_board(*args, **kwargs):
return [mock_image_dto.image_name]
monkeypatch.setattr(
mock_invoker.services.board_image_records,
"get_all_board_image_names_for_board",
mock_get_all_board_image_names_for_board,
)
# Create a mock image file so that the contents of the zip file are not empty # Create a mock image file so that the contents of the zip file are not empty
mock_image_path: Path = tmp_path / mock_image_dto.image_name mock_image_path: Path = tmp_path / mock_image_dto.image_name
mock_image_contents: str = "Totally an image" mock_image_contents: str = "Totally an image"
@ -306,7 +263,7 @@ def assert_handler_success(
expected_image_path: Path, expected_image_path: Path,
mock_image_contents: str, mock_image_contents: str,
tmp_path: Path, tmp_path: Path,
event_bus: DummyEventService, event_bus: TestEventService,
): ):
"""Assert that the handler was successful.""" """Assert that the handler was successful."""
# Check that the zip file was created # Check that the zip file was created
@ -369,7 +326,7 @@ def test_handler_on_generic_exception(
with pytest.raises(Exception): # noqa: B017 with pytest.raises(Exception): # noqa: B017
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception) execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
event_bus: DummyEventService = mock_invoker.services.events event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2 assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started" assert event_bus.events[0].event_name == "bulk_download_started"
@ -384,7 +341,7 @@ def execute_handler_test_on_error(
bulk_download_service.start(mock_invoker) bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None) bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: DummyEventService = mock_invoker.services.events event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2 assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started" assert event_bus.events[0].event_name == "bulk_download_started"

View File

@ -9,7 +9,7 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.fixtures.event_service import DummyEventService from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings # Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore TestAdapter.__test__ = False # type: ignore
@ -101,7 +101,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService() event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start() queue.start()
@ -167,7 +167,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
def test_cancel(tmp_path: Path, session: Session) -> None: def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService() event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start() queue.start()

View File

@ -4,4 +4,57 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not # We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures. # play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import torch_device # noqa: F401 import logging
import pytest
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
board_image_records=SqliteBoardImageRecordStorage(db=db),
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=configuration,
events=TestEventService(),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)

View File

@ -1,27 +0,0 @@
from typing import Any, Dict, List
from pydantic import BaseModel
from invokeai.app.services.events.events_base import EventServiceBase
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
# A dummy event service for testing event issuing
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))

View File

@ -1,4 +1,3 @@
import logging
from typing import Optional from typing import Optional
from unittest.mock import Mock from unittest.mock import Mock
@ -8,17 +7,12 @@ import pytest
from .test_nodes import ( # isort: split from .test_nodes import ( # isort: split
PromptCollectionTestInvocation, PromptCollectionTestInvocation,
PromptTestInvocation, PromptTestInvocation,
TestEventService,
TextToImageTestInvocation, TextToImageTestInvocation,
) )
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.shared.graph import ( from invokeai.app.services.shared.graph import (
CollectInvocation, CollectInvocation,
Graph, Graph,
@ -38,39 +32,6 @@ def simple_graph() -> Graph:
return g return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=None, # type: ignore
boards=None, # type: ignore
bulk_download=None, # type: ignore
configuration=configuration,
events=TestEventService(),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]: def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
n = g.next() n = g.next()
if n is None: if n is None:

View File

@ -1,5 +1,7 @@
from typing import Any, Callable, Union from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -115,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
) )
class TestEvent: class TestEvent(BaseModel):
event_name: str
payload: Any
__test__ = False # not a pytest test case __test__ = False # not a pytest test case
def __init__(self, event_name: str, payload: Any): event_name: str
self.event_name = event_name payload: Any
self.payload = payload
class TestEventService(EventServiceBase): class TestEventService(EventServiceBase):
events: list
__test__ = False # not a pytest test case __test__ = False # not a pytest test case
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.events = [] self.events: list[TestEvent] = []
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
pass pass