test: clean up & fix tests

- Deduplicate the mock invocation services. This is possible now that the import order issue is resolved.
- Merge `DummyEventService` into `TestEventService` and update all tests to use `TestEventService`.
This commit is contained in:
psychedelicious
2024-02-20 18:49:55 +11:00
parent 7f75f6226b
commit 5f64ed5bd5
7 changed files with 78 additions and 184 deletions

View File

@ -1,66 +1,17 @@
from pathlib import Path
from typing import Any
import pytest
from fastapi import BackgroundTasks
from fastapi.testclient import TestClient
from invokeai.app.api.dependencies import ApiDependencies
from invokeai.app.api_app import app
from invokeai.app.services.board_records.board_records_common import BoardRecord
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
client = TestClient(app)
@pytest.fixture
def mock_services(tmp_path: Path) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=BulkDownloadService(),
configuration=None, # type: ignore
events=None, # type: ignore
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
class MockApiDependencies(ApiDependencies):
invoker: Invoker

View File

@ -7,23 +7,17 @@ from zipfile import ZipFile
import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
ResourceOrigin,
)
from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.event_service import DummyEventService
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from tests.test_nodes import TestEventService
@pytest.fixture
@ -46,53 +40,6 @@ def mock_image_dto() -> ImageDTO:
)
@pytest.fixture
def mock_event_service() -> DummyEventService:
"""Create a dummy event service."""
return DummyEventService()
@pytest.fixture
def mock_services(mock_event_service: DummyEventService) -> InvocationServices:
configuration = InvokeAIAppConfig(use_memory_db=True, node_cache_size=0)
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(configuration, logger)
return InvocationServices(
board_image_records=None, # type: ignore
board_images=None, # type: ignore
board_records=SqliteBoardRecordStorage(db=db),
boards=None, # type: ignore
bulk_download=None, # type: ignore
configuration=None, # type: ignore
events=mock_event_service,
graph_execution_manager=None, # type: ignore
image_files=None, # type: ignore
image_records=None, # type: ignore
images=ImageService(),
invocation_cache=None, # type: ignore
latents=None, # type: ignore
logger=logger,
model_manager=None, # type: ignore
model_records=None, # type: ignore
download_queue=None, # type: ignore
model_install=None, # type: ignore
names=None, # type: ignore
performance_statistics=None, # type: ignore
processor=None, # type: ignore
queue=None, # type: ignore
session_processor=None, # type: ignore
session_queue=None, # type: ignore
urls=None, # type: ignore
workflow_records=None, # type: ignore
)
@pytest.fixture()
def mock_invoker(mock_services: InvocationServices) -> Invoker:
return Invoker(services=mock_services)
@pytest.fixture(autouse=True)
def mock_temporary_directory(monkeypatch: Any, tmp_path: Path):
"""Mock the TemporaryDirectory class so that it uses the tmp_path fixture."""
@ -288,6 +235,16 @@ def prepare_handler_test(tmp_path: Path, monkeypatch: Any, mock_image_dto: Image
monkeypatch.setattr(mock_invoker.services.images, "get_dto", mock_get_dto)
# This is used when preparing all images for a given board
def mock_get_all_board_image_names_for_board(*args, **kwargs):
return [mock_image_dto.image_name]
monkeypatch.setattr(
mock_invoker.services.board_image_records,
"get_all_board_image_names_for_board",
mock_get_all_board_image_names_for_board,
)
# Create a mock image file so that the contents of the zip file are not empty
mock_image_path: Path = tmp_path / mock_image_dto.image_name
mock_image_contents: str = "Totally an image"
@ -306,7 +263,7 @@ def assert_handler_success(
expected_image_path: Path,
mock_image_contents: str,
tmp_path: Path,
event_bus: DummyEventService,
event_bus: TestEventService,
):
"""Assert that the handler was successful."""
# Check that the zip file was created
@ -369,7 +326,7 @@ def test_handler_on_generic_exception(
with pytest.raises(Exception): # noqa: B017
execute_handler_test_on_error(tmp_path, monkeypatch, mock_image_dto, mock_invoker, exception)
event_bus: DummyEventService = mock_invoker.services.events
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
@ -384,7 +341,7 @@ def execute_handler_test_on_error(
bulk_download_service.start(mock_invoker)
bulk_download_service.handler([mock_image_dto.image_name], None, None)
event_bus: DummyEventService = mock_invoker.services.events
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"

View File

@ -9,7 +9,7 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.fixtures.event_service import DummyEventService
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
@ -101,7 +101,7 @@ def test_errors(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
@ -167,7 +167,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = DummyEventService()
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()