diff --git a/tests/app/services/bulk_download/test_bulk_download.py b/tests/app/services/bulk_download/test_bulk_download.py index b18f6e038d..bf3bd27993 100644 --- a/tests/app/services/bulk_download/test_bulk_download.py +++ b/tests/app/services/bulk_download/test_bulk_download.py @@ -9,6 +9,11 @@ import pytest from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService +from invokeai.app.services.events.events_common import ( + BulkDownloadCompleteEvent, + BulkDownloadErrorEvent, + BulkDownloadStartedEvent, +) from invokeai.app.services.image_records.image_records_common import ( ImageCategory, ImageRecordNotFoundException, @@ -281,9 +286,9 @@ def assert_handler_success( # Check that the correct events were emitted assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_completed" - assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent) + assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path) def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): @@ -329,9 +334,9 @@ def test_handler_on_generic_exception( event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_failed" - assert event_bus.events[1].payload["error"] == exception.__str__() + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadErrorEvent) + assert event_bus.events[1].error == exception.__str__() def execute_handler_test_on_error( @@ -344,9 +349,9 @@ def execute_handler_test_on_error( event_bus: TestEventService = mock_invoker.services.events assert len(event_bus.events) == 2 - assert event_bus.events[0].event_name == "bulk_download_started" - assert event_bus.events[1].event_name == "bulk_download_failed" - assert event_bus.events[1].payload["error"] == error.__str__() + assert isinstance(event_bus.events[0], BulkDownloadStartedEvent) + assert isinstance(event_bus.events[1], BulkDownloadErrorEvent) + assert event_bus.events[1].error == error.__str__() def test_delete(tmp_path: Path): diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 307238fd61..98fed861ae 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -10,6 +10,13 @@ from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService +from invokeai.app.services.events.events_common import ( + DownloadCancelledEvent, + DownloadCompleteEvent, + DownloadErrorEvent, + DownloadProgressEvent, + DownloadStartedEvent, +) from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings @@ -116,14 +123,22 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: queue.join() events = event_bus.events assert len(events) == 3 - assert events[0].payload["timestamp"] <= events[1].payload["timestamp"] - assert events[1].payload["timestamp"] <= events[2].payload["timestamp"] - assert events[0].event_name == "download_started" - assert events[1].event_name == "download_progress" - assert events[1].payload["total_bytes"] > 0 - assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"] - assert events[2].event_name == "download_complete" - assert events[2].payload["total_bytes"] == 32029 + assert isinstance(events[0], DownloadStartedEvent) + assert isinstance(events[1], DownloadProgressEvent) + assert isinstance(events[2], DownloadCompleteEvent) + assert events[0].timestamp <= events[1].timestamp + assert events[1].timestamp <= events[2].timestamp + assert events[1].total_bytes > 0 + assert events[1].current_bytes <= events[1].total_bytes + assert events[2].total_bytes == 32029 + # assert events[0].payload["timestamp"] <= events[1].payload["timestamp"] + # assert events[1].payload["timestamp"] <= events[2].payload["timestamp"] + # assert events[0].event_name == "download_started" + # assert events[1].event_name == "download_progress" + # assert events[1].payload["total_bytes"] > 0 + # assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"] + # assert events[2].event_name == "download_complete" + # assert events[2].payload["total_bytes"] == 32029 # test a failure event_bus.events = [] # reset our accumulator @@ -132,10 +147,15 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: events = event_bus.events print("\n".join([x.model_dump_json() for x in events])) assert len(events) == 1 - assert events[0].event_name == "download_error" - assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)" - assert events[0].payload["error"] is not None - assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"]) + assert isinstance(events[0], DownloadErrorEvent) + assert events[0].error_type == "HTTPError(NOT FOUND)" + assert events[0].error is not None + assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error) + + # assert events[0].event_name == "download_error" + # assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)" + # assert events[0].payload["error"] is not None + # assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"]) queue.stop() @@ -202,6 +222,8 @@ def test_cancel(tmp_path: Path, session: Session) -> None: assert job.status == DownloadJobStatus.CANCELLED assert cancelled events = event_bus.events - assert events[-1].event_name == "download_cancelled" - assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" + assert isinstance(events[-1], DownloadCancelledEvent) + assert events[-1].source == "http://www.civitai.com/models/12345" + # assert events[-1].event_name == "download_cancelled" + # assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index cf7ebe8d29..b9e6415c46 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -13,6 +13,12 @@ from pydantic_core import Url from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.events.events_base import EventServiceBase +from invokeai.app.services.events.events_common import ( + ModelInstallCompleteEvent, + ModelInstallDownloadProgressEvent, + ModelInstallDownloadsCompleteEvent, + ModelInstallStartedEvent, +) from invokeai.app.services.model_install import ( ModelInstallServiceBase, ) @@ -25,6 +31,7 @@ from invokeai.app.services.model_install.model_install_common import ( from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 +from tests.test_nodes import TestEventService OS = platform.uname().system @@ -132,19 +139,26 @@ def test_background_install( assert job.total_bytes == size # test that the expected events were issued - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus assert bus assert hasattr(bus, "events") assert len(bus.events) == 2 - event_names = [x.event_name for x in bus.events] - assert "model_install_running" in event_names - assert "model_install_completed" in event_names - assert Path(bus.events[0].payload["source"]) == source - assert Path(bus.events[1].payload["source"]) == source - key = bus.events[1].payload["key"] + assert isinstance(bus.events[0], ModelInstallStartedEvent) + assert isinstance(bus.events[1], ModelInstallCompleteEvent) + assert Path(bus.events[0].source) == source + assert Path(bus.events[1].source) == source + key = bus.events[1].key assert key is not None + # event_names = [x.event_name for x in bus.events] + # assert "model_install_running" in event_names + # assert "model_install_completed" in event_names + # assert Path(bus.events[0].payload["source"]) == source + # assert Path(bus.events[1].payload["source"]) == source + # key = bus.events[1].payload["key"] + # assert key is not None + # see if the thing actually got installed at the expected location model_record = mm2_installer.record_store.get_model(key) assert model_record is not None @@ -221,7 +235,7 @@ def test_delete_register( def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus store = mm2_installer.record_store assert store is not None assert bus is not None @@ -239,20 +253,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: assert (mm2_app_config.models_path / model_record.path).exists() assert len(bus.events) == 4 - event_names = [x.event_name for x in bus.events] - assert event_names == [ - "model_install_downloading", - "model_install_downloads_done", - "model_install_running", - "model_install_completed", - ] + assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) + assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) + assert isinstance(bus.events[2], ModelInstallStartedEvent) + assert isinstance(bus.events[3], ModelInstallCompleteEvent) @pytest.mark.timeout(timeout=20, method="thread") def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) - bus = mm2_installer.event_bus + bus: TestEventService = mm2_installer.event_bus store = mm2_installer.record_store assert isinstance(bus, EventServiceBase) assert store is not None @@ -269,15 +280,10 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co assert model_record.type == ModelType.Main assert model_record.format == ModelFormat.Diffusers - assert hasattr(bus, "events") # the dummyeventservice has this + assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events) + assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events) + assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events) assert len(bus.events) >= 3 - event_names = {x.event_name for x in bus.events} - assert event_names == { - "model_install_downloading", - "model_install_downloads_done", - "model_install_running", - "model_install_completed", - } def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 980f6ea17b..5ddccd05bb 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -3,16 +3,13 @@ import os import shutil from pathlib import Path -from typing import Any, Dict, List import pytest -from pydantic import BaseModel from requests.sessions import Session from requests_testadapter import TestAdapter, TestSession from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase -from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase @@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import ( RepoHFModelJson1, ) from tests.fixtures.sqlite_database import create_mock_sqlite_database - - -class DummyEvent(BaseModel): - """Dummy Event to use with Dummy Event service.""" - - event_name: str - payload: Dict[str, Any] - - -class DummyEventService(EventServiceBase): - """Dummy event service for testing.""" - - events: List[DummyEvent] - - def __init__(self) -> None: - super().__init__() - self.events = [] - - def dispatch(self, event_name: str, payload: Any) -> None: - """Dispatch an event by appending it to self.events.""" - self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"])) +from tests.test_nodes import TestEventService # Create a temporary directory using the contents of `./data/invokeai_root` as the template @@ -127,7 +104,7 @@ def mm2_installer( ) -> ModelInstallServiceBase: logger = InvokeAILogger.get_logger() db = create_mock_sqlite_database(mm2_app_config, logger) - events = DummyEventService() + events = TestEventService() store = ModelRecordServiceSQL(db) installer = ModelInstallService( diff --git a/tests/test_nodes.py b/tests/test_nodes.py index e1fe857040..2d413a2687 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,7 +1,5 @@ from typing import Any, Callable, Union -from pydantic import BaseModel - from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.image import ImageField +from invokeai.app.services.events.events_common import EventBase from invokeai.app.services.shared.invocation_context import InvocationContext @@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg ) -class TestEvent(BaseModel): +class TestEvent(EventBase): __test__ = False # not a pytest test case - event_name: str - payload: Any + __event_name__ = "test_event" class TestEventService(EventServiceBase): @@ -129,10 +127,10 @@ class TestEventService(EventServiceBase): def __init__(self): super().__init__() - self.events: list[TestEvent] = [] + self.events: list[EventBase] = [] - def dispatch(self, event_name: str, payload: Any) -> None: - self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"])) + def dispatch(self, event: EventBase) -> None: + self.events.append(event) pass