tests: update tests to use new events

This commit is contained in:
psychedelicious 2024-03-14 18:51:17 +11:00
parent 655f62008f
commit a876675448
5 changed files with 88 additions and 80 deletions

View File

@ -9,6 +9,11 @@ import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.events.events_common import (
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
)
from invokeai.app.services.image_records.image_records_common import ( from invokeai.app.services.image_records.image_records_common import (
ImageCategory, ImageCategory,
ImageRecordNotFoundException, ImageRecordNotFoundException,
@ -281,9 +286,9 @@ def assert_handler_success(
# Check that the correct events were emitted # Check that the correct events were emitted
assert len(event_bus.events) == 2 assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started" assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert event_bus.events[1].event_name == "bulk_download_completed" assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path) assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker): def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
@ -329,9 +334,9 @@ def test_handler_on_generic_exception(
event_bus: TestEventService = mock_invoker.services.events event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2 assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started" assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert event_bus.events[1].event_name == "bulk_download_failed" assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].payload["error"] == exception.__str__() assert event_bus.events[1].error == exception.__str__()
def execute_handler_test_on_error( def execute_handler_test_on_error(
@ -344,9 +349,9 @@ def execute_handler_test_on_error(
event_bus: TestEventService = mock_invoker.services.events event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2 assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started" assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert event_bus.events[1].event_name == "bulk_download_failed" assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].payload["error"] == error.__str__() assert event_bus.events[1].error == error.__str__()
def test_delete(tmp_path: Path): def test_delete(tmp_path: Path):

View File

@ -10,6 +10,13 @@ from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
)
from tests.test_nodes import TestEventService from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings # Prevent pytest deprecation warnings
@ -116,14 +123,22 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
queue.join() queue.join()
events = event_bus.events events = event_bus.events
assert len(events) == 3 assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"] assert isinstance(events[0], DownloadStartedEvent)
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"] assert isinstance(events[1], DownloadProgressEvent)
assert events[0].event_name == "download_started" assert isinstance(events[2], DownloadCompleteEvent)
assert events[1].event_name == "download_progress" assert events[0].timestamp <= events[1].timestamp
assert events[1].payload["total_bytes"] > 0 assert events[1].timestamp <= events[2].timestamp
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"] assert events[1].total_bytes > 0
assert events[2].event_name == "download_complete" assert events[1].current_bytes <= events[1].total_bytes
assert events[2].payload["total_bytes"] == 32029 assert events[2].total_bytes == 32029
# assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
# assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
# assert events[0].event_name == "download_started"
# assert events[1].event_name == "download_progress"
# assert events[1].payload["total_bytes"] > 0
# assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
# assert events[2].event_name == "download_complete"
# assert events[2].payload["total_bytes"] == 32029
# test a failure # test a failure
event_bus.events = [] # reset our accumulator event_bus.events = [] # reset our accumulator
@ -132,10 +147,15 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
events = event_bus.events events = event_bus.events
print("\n".join([x.model_dump_json() for x in events])) print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1 assert len(events) == 1
assert events[0].event_name == "download_error" assert isinstance(events[0], DownloadErrorEvent)
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)" assert events[0].error_type == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None assert events[0].error is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"]) assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
# assert events[0].event_name == "download_error"
# assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
# assert events[0].payload["error"] is not None
# assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop() queue.stop()
@ -202,6 +222,8 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert job.status == DownloadJobStatus.CANCELLED assert job.status == DownloadJobStatus.CANCELLED
assert cancelled assert cancelled
events = event_bus.events events = event_bus.events
assert events[-1].event_name == "download_cancelled" assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" assert events[-1].source == "http://www.civitai.com/models/12345"
# assert events[-1].event_name == "download_cancelled"
# assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop() queue.stop()

View File

@ -13,6 +13,12 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import ( from invokeai.app.services.model_install import (
ModelInstallServiceBase, ModelInstallServiceBase,
) )
@ -25,6 +31,7 @@ from invokeai.app.services.model_install.model_install_common import (
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType from invokeai.backend.model_manager.config import BaseModelType, InvalidModelConfigException, ModelFormat, ModelType
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system OS = platform.uname().system
@ -132,19 +139,26 @@ def test_background_install(
assert job.total_bytes == size assert job.total_bytes == size
# test that the expected events were issued # test that the expected events were issued
bus = mm2_installer.event_bus bus: TestEventService = mm2_installer.event_bus
assert bus assert bus
assert hasattr(bus, "events") assert hasattr(bus, "events")
assert len(bus.events) == 2 assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events] assert isinstance(bus.events[0], ModelInstallStartedEvent)
assert "model_install_running" in event_names assert isinstance(bus.events[1], ModelInstallCompleteEvent)
assert "model_install_completed" in event_names assert Path(bus.events[0].source) == source
assert Path(bus.events[0].payload["source"]) == source assert Path(bus.events[1].source) == source
assert Path(bus.events[1].payload["source"]) == source key = bus.events[1].key
key = bus.events[1].payload["key"]
assert key is not None assert key is not None
# event_names = [x.event_name for x in bus.events]
# assert "model_install_running" in event_names
# assert "model_install_completed" in event_names
# assert Path(bus.events[0].payload["source"]) == source
# assert Path(bus.events[1].payload["source"]) == source
# key = bus.events[1].payload["key"]
# assert key is not None
# see if the thing actually got installed at the expected location # see if the thing actually got installed at the expected location
model_record = mm2_installer.record_store.get_model(key) model_record = mm2_installer.record_store.get_model(key)
assert model_record is not None assert model_record is not None
@ -221,7 +235,7 @@ def test_delete_register(
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors")) source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus = mm2_installer.event_bus bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store store = mm2_installer.record_store
assert store is not None assert store is not None
assert bus is not None assert bus is not None
@ -239,20 +253,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert (mm2_app_config.models_path / model_record.path).exists() assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4 assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events] assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert event_names == [ assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
"model_install_downloading", assert isinstance(bus.events[2], ModelInstallStartedEvent)
"model_install_downloads_done", assert isinstance(bus.events[3], ModelInstallCompleteEvent)
"model_install_running",
"model_install_completed",
]
@pytest.mark.timeout(timeout=20, method="thread") @pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo")) source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase) assert isinstance(bus, EventServiceBase)
assert store is not None assert store is not None
@ -269,15 +280,10 @@ def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_co
assert model_record.type == ModelType.Main assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
assert len(bus.events) >= 3 assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:

View File

@ -3,16 +3,13 @@
import os import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List
import pytest import pytest
from pydantic import BaseModel
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFModelJson1, RepoHFModelJson1,
) )
from tests.fixtures.sqlite_database import create_mock_sqlite_database from tests.fixtures.sqlite_database import create_mock_sqlite_database
from tests.test_nodes import TestEventService
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
# Create a temporary directory using the contents of `./data/invokeai_root` as the template # Create a temporary directory using the contents of `./data/invokeai_root` as the template
@ -127,7 +104,7 @@ def mm2_installer(
) -> ModelInstallServiceBase: ) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger() logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger) db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService() events = TestEventService()
store = ModelRecordServiceSQL(db) store = ModelRecordServiceSQL(db)
installer = ModelInstallService( installer = ModelInstallService(

View File

@ -1,7 +1,5 @@
from typing import Any, Callable, Union from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
) )
from invokeai.app.invocations.fields import InputField, OutputField from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField from invokeai.app.invocations.image import ImageField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
) )
class TestEvent(BaseModel): class TestEvent(EventBase):
__test__ = False # not a pytest test case __test__ = False # not a pytest test case
event_name: str __event_name__ = "test_event"
payload: Any
class TestEventService(EventServiceBase): class TestEventService(EventServiceBase):
@ -129,10 +127,10 @@ class TestEventService(EventServiceBase):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.events: list[TestEvent] = [] self.events: list[EventBase] = []
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event: EventBase) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"])) self.events.append(event)
pass pass