merge with main and resolve conflicts

This commit is contained in:
Lincoln Stein
2024-05-27 22:20:34 -04:00
256 changed files with 9360 additions and 6061 deletions

View File

@ -9,6 +9,11 @@ import pytest
from invokeai.app.services.board_records.board_records_common import BoardRecord, BoardRecordNotFoundException
from invokeai.app.services.bulk_download.bulk_download_common import BulkDownloadTargetException
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.events.events_common import (
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadStartedEvent,
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordNotFoundException,
@ -281,9 +286,9 @@ def assert_handler_success(
# Check that the correct events were emitted
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_completed"
assert event_bus.events[1].payload["bulk_download_item_name"] == os.path.basename(expected_zip_path)
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadCompleteEvent)
assert event_bus.events[1].bulk_download_item_name == os.path.basename(expected_zip_path)
def test_handler_on_image_not_found(tmp_path: Path, monkeypatch: Any, mock_image_dto: ImageDTO, mock_invoker: Invoker):
@ -329,9 +334,9 @@ def test_handler_on_generic_exception(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == exception.__str__()
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == exception.__str__()
def execute_handler_test_on_error(
@ -344,9 +349,9 @@ def execute_handler_test_on_error(
event_bus: TestEventService = mock_invoker.services.events
assert len(event_bus.events) == 2
assert event_bus.events[0].event_name == "bulk_download_started"
assert event_bus.events[1].event_name == "bulk_download_failed"
assert event_bus.events[1].payload["error"] == error.__str__()
assert isinstance(event_bus.events[0], BulkDownloadStartedEvent)
assert isinstance(event_bus.events[1], BulkDownloadErrorEvent)
assert event_bus.events[1].error == error.__str__()
def test_delete(tmp_path: Path):

View File

@ -14,6 +14,13 @@ from requests_testadapter import TestAdapter
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.app.services.events.events_common import (
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
)
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
@ -88,14 +95,14 @@ def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
queue.join()
events = event_bus.events
assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
assert isinstance(events[0], DownloadStartedEvent)
assert isinstance(events[1], DownloadProgressEvent)
assert isinstance(events[2], DownloadCompleteEvent)
assert events[0].timestamp <= events[1].timestamp
assert events[1].timestamp <= events[2].timestamp
assert events[1].total_bytes > 0
assert events[1].current_bytes <= events[1].total_bytes
assert events[2].total_bytes == 32029
# test a failure
event_bus.events = [] # reset our accumulator
@ -104,10 +111,10 @@ def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
assert isinstance(events[0], DownloadErrorEvent)
assert events[0].error_type == "HTTPError(NOT FOUND)"
assert events[0].error is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].error)
queue.stop()
@ -171,8 +178,8 @@ def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
assert isinstance(events[-1], DownloadCancelledEvent)
assert events[-1].source == "http://www.civitai.com/models/12345"
queue.stop()
@ -278,7 +285,7 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert "download_cancelled" in [x.event_name for x in events]
assert DownloadCancelledEvent in [type(x) for x in events]
queue.stop()

View File

@ -13,12 +13,20 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
HFModelSource,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
@ -30,6 +38,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@ -137,17 +146,16 @@ def test_background_install(
assert job.total_bytes == size
# test that the expected events were issued
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
assert bus
assert hasattr(bus, "events")
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_running" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert isinstance(bus.events[0], ModelInstallStartedEvent)
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
assert Path(bus.events[0].source) == source
assert Path(bus.events[1].source) == source
key = bus.events[1].key
assert key is not None
# see if the thing actually got installed at the expected location
@ -226,7 +234,7 @@ def test_delete_register(
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert store is not None
assert bus is not None
@ -244,20 +252,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
assert event_names == [
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
]
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
@ -274,15 +279,10 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
@pytest.mark.timeout(timeout=10, method="thread")
@ -308,19 +308,24 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
event_types = [type(x) for x in bus.events]
assert all(
x in event_types
for x in [
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
]
)
completed_events = [x for x in bus.events if x.event_name == "model_install_completed"]
downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"]
assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"]
assert job.total_bytes == completed_events[0].payload["total_bytes"]
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
assert completed_events[0].total_bytes == downloading_events[-1].bytes
assert job.total_bytes == completed_events[0].total_bytes
print(downloading_events[-1])
print(job.download_parts)
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:

View File

@ -19,7 +19,7 @@ def mock_context(
return build_invocation_context(
services=mock_services,
data=None, # type: ignore
cancel_event=None, # type: ignore
is_canceled=None, # type: ignore
)

View File

@ -3,16 +3,13 @@
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
@ -39,27 +36,7 @@ from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
class DummyEvent(BaseModel):
"""Dummy Event to use with Dummy Event service."""
event_name: str
payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
"""Dummy event service for testing."""
events: List[DummyEvent]
def __init__(self) -> None:
super().__init__()
self.events = []
def dispatch(self, event_name: str, payload: Any) -> None:
"""Dispatch an event by appending it to self.events."""
self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))
from tests.test_nodes import TestEventService
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@ -127,7 +104,7 @@ def mm2_installer(
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
events = TestEventService()
store = ModelRecordServiceSQL(db)
installer = ModelInstallService(

View File

@ -20,6 +20,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
from invokeai.app.services.invoker import Invoker
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_manager_fixtures import * # noqa: F403
from tests.fixtures.sqlite_database import create_mock_sqlite_database # noqa: F401
from tests.test_nodes import TestEventService

View File

@ -1,7 +1,5 @@
from typing import Any, Callable, Union
from pydantic import BaseModel
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
@ -10,6 +8,7 @@ from invokeai.app.invocations.baseinvocation import (
)
from invokeai.app.invocations.fields import InputField, OutputField
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.events.events_common import EventBase
from invokeai.app.services.shared.invocation_context import InvocationContext
@ -117,11 +116,10 @@ def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edg
)
class TestEvent(BaseModel):
class TestEvent(EventBase):
__test__ = False # not a pytest test case
event_name: str
payload: Any
__event_name__ = "test_event"
class TestEventService(EventServiceBase):
@ -129,10 +127,10 @@ class TestEventService(EventServiceBase):
def __init__(self):
super().__init__()
self.events: list[TestEvent] = []
self.events: list[EventBase] = []
def dispatch(self, event_name: str, payload: Any) -> None:
self.events.append(TestEvent(event_name=payload["event"], payload=payload["data"]))
def dispatch(self, event: EventBase) -> None:
self.events.append(event)
pass