test: clean up & fix tests

- Deduplicate the mock invocation services. This is possible now that the import order issue is resolved.
- Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`.
This commit is contained in:
psychedelicious 2024-02-20 18:49:55 +11:00
parent cbb997e7d0
commit 5cba55d670
7 changed files with 78 additions and 184 deletions

View File

@ -1,66 +1,17 @@
from pathlib import Path
from typing import Any
import pytest
from fastapi import BackgroundTasks
from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
client = TestClient(app)
@pytest.fixture
def mock_services(tmp_path: Path) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=None, # type: ignore
events=None, # type: ignore
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
class MockApiDependencies(ApiDependencies):
invoker: Invoker

View File

@ -7,23 +7,17 @@ from zipfile import ZipFile
import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.event_service import DummyEventService
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from tests.test_nodes import TestEventService
@pytest.fixture
@ -46,53 +40,6 @@ def mock_image_dto() -> ImageDTO:
)
@pytest.fixture
def mock_event_service() -> DummyEventService:
"""Create a dummy event service."""
return DummyEventService()
@pytest.fixture
def mock_services(mock_event_service: DummyEventService) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=None, # type: ignore
configuration=None, # type: ignore
events=mock_event_service,
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
@pytest.fixture(autouse=True)
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
@ -288,6 +235,16 @@ def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: Image
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
# This is used when preparing all images for a given board
def mock_get_all_board_image_names_for_board(*args, **kwargs):
return [mock_image_dto.image_name]
monkeypatch.setattr(
mock_invoker.services.board_image_records,
"get_all_board_image_names_for_board",
mock_get_all_board_image_names_for_board,
)
# Create a mock image file so that the contents of the zip file are not empty
mock_image_path: Path = tmp_path / mock_image_dto.image_name
mock_image_contents: str = "Totally an image"
@ -306,7 +263,7 @@ def assert_handler_success(
expected_image_path: Path,
mock_image_contents: str,
tmp_path: Path,
event_bus: DummyEventService,
event_bus: TestEventService,
):
"""Assert that the handler was successful."""
# Check that the zip file was created
@ -369,7 +326,7 @@ def test_handler_on_generic_exception(
with pytest.raises(Exception): # noqa: B017
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
event_bus: DummyEventService = mock_invoker.services.events
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
@ -384,7 +341,7 @@ def execute_handler_test_on_error(
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: DummyEventService = mock_invoker.services.events
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"

View File

@ -9,7 +9,7 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.fixtures.event_service import DummyEventService
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
@ -101,7 +101,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
@ -167,7 +167,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()

View File

@ -4,4 +4,57 @@
# We import the model_installer and torch_device fixtures here so that they can be used by all tests. Flake8 does not
# play well with fixtures (F401 and F811), so this is cleaner than importing in all files that use these fixtures.
from invokeai.backend.util.test_utils import torch_device # noqa: F401
import logging
import pytest
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
board_image_records=SqliteBoardImageRecordStorage(db=db),
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=configuration,
events=TestEventService(),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)

View File

@ -1,27 +0,0 @@
from typing import Any, Dict, List
from pydantic import BaseModel
from invokeai.app.services.events.events_base import EventServiceBase
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
# A dummy event service for testing event issuing
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))

View File

@ -1,4 +1,3 @@
import logging
from typing import Optional
from unittest.mock import Mock
@ -8,17 +7,12 @@ import pytest
from .test_nodes import ( # isort: split
PromptCollectionTestInvocation,
PromptTestInvocation,
TestEventService,
TextToImageTestInvocation,
)
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from invokeai.app.invocations.collections import RangeInvocation
from invokeai.app.invocations.math import AddInvocation, MultiplyInvocation
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.shared.graph import (
CollectInvocation,
Graph,
@ -38,39 +32,6 @@ def simple_graph() -> Graph:
return g
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
# the test invocations.
@pytest.fixture
def mock_services() -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
# NOTE: none of these are actually called by the test invocations
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=None, # type: ignore
boards=None, # type: ignore
bulk_download=None, # type: ignore
configuration=configuration,
events=TestEventService(),
image_files=None, # type: ignore
image_records=None, # type: ignore
images=None, # type: ignore
invocation_cache=MemoryInvocationCache(max_cache_size=0),
logger=logging, # type: ignore
model_manager=None, # type: ignore
download_queue=None, # type: ignore
names=None, # type: ignore
performance_statistics=InvocationStatsService(),
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
tensors=None, # type: ignore
conditioning=None, # type: ignore
)
def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optional[BaseInvocationOutput]]:
n = g.next()
if n is None:

View File

@ -1,5 +1,7 @@
from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -115,25 +117,22 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
)
class TestEvent:
event_name: str
payload: Any
class TestEvent(BaseModel):
__test__ = False # not a pytest test case
def __init__(self, event_name: str, payload: Any):
self.event_name = event_name
self.payload = payload
event_name: str
payload: Any
class TestEventService(EventServiceBase):
events: list
__test__ = False # not a pytest test case
def __init__(self):
super().__init__()
self.events = []
self.events: list[TestEvent] = []
def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
pass