merge with main and resolve conflicts

This commit is contained in:
Lincoln Stein
2024-05-27 22:20:34 -04:00
256 changed files with 9360 additions and 6061 deletions

View File

@ -13,12 +13,20 @@ from pydantic_core import Url
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import (
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
)
from invokeai.app.services.model_install import (
HFModelSource,
ModelInstallServiceBase,
)
from invokeai.app.services.model_install.model_install_common import (
InstallStatus,
LocalModelSource,
ModelInstallJob,
ModelInstallServiceBase,
URLModelSource,
)
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
@ -30,6 +38,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
)
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
OS = platform.uname().system
@ -137,17 +146,16 @@ def test_background_install(
assert job.total_bytes == size
# test that the expected events were issued
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
assert bus
assert hasattr(bus, "events")
assert len(bus.events) == 2
event_names = [x.event_name for x in bus.events]
assert "model_install_running" in event_names
assert "model_install_completed" in event_names
assert Path(bus.events[0].payload["source"]) == source
assert Path(bus.events[1].payload["source"]) == source
key = bus.events[1].payload["key"]
assert isinstance(bus.events[0], ModelInstallStartedEvent)
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
assert Path(bus.events[0].source) == source
assert Path(bus.events[1].source) == source
key = bus.events[1].key
assert key is not None
# see if the thing actually got installed at the expected location
@ -226,7 +234,7 @@ def test_delete_register(
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert store is not None
assert bus is not None
@ -244,20 +252,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
assert (mm2_app_config.models_path / model_record.path).exists()
assert len(bus.events) == 4
event_names = [x.event_name for x in bus.events]
assert event_names == [
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
]
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
assert isinstance(bus.events[2], ModelInstallStartedEvent)
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
@pytest.mark.timeout(timeout=10, method="thread")
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
bus = mm2_installer.event_bus
bus: TestEventService = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
@ -274,15 +279,10 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
@pytest.mark.timeout(timeout=10, method="thread")
@ -308,19 +308,24 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
event_types = [type(x) for x in bus.events]
assert all(
x in event_types
for x in [
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
ModelInstallStartedEvent,
ModelInstallCompleteEvent,
]
)
completed_events = [x for x in bus.events if x.event_name == "model_install_completed"]
downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"]
assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"]
assert job.total_bytes == completed_events[0].payload["total_bytes"]
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
assert completed_events[0].total_bytes == downloading_events[-1].bytes
assert job.total_bytes == completed_events[0].total_bytes
print(downloading_events[-1])
print(job.download_parts)
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: