mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
merge with main and resolve conflicts
This commit is contained in:
@ -13,12 +13,20 @@ from pydantic_core import Url
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.events.events_common import (
|
||||
ModelInstallCompleteEvent,
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
)
|
||||
from invokeai.app.services.model_install import (
|
||||
HFModelSource,
|
||||
ModelInstallServiceBase,
|
||||
)
|
||||
from invokeai.app.services.model_install.model_install_common import (
|
||||
InstallStatus,
|
||||
LocalModelSource,
|
||||
ModelInstallJob,
|
||||
ModelInstallServiceBase,
|
||||
URLModelSource,
|
||||
)
|
||||
from invokeai.app.services.model_records import ModelRecordChanges, UnknownModelException
|
||||
@ -30,6 +38,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
)
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
OS = platform.uname().system
|
||||
|
||||
@ -137,17 +146,16 @@ def test_background_install(
|
||||
assert job.total_bytes == size
|
||||
|
||||
# test that the expected events were issued
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
assert bus
|
||||
assert hasattr(bus, "events")
|
||||
|
||||
assert len(bus.events) == 2
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert "model_install_running" in event_names
|
||||
assert "model_install_completed" in event_names
|
||||
assert Path(bus.events[0].payload["source"]) == source
|
||||
assert Path(bus.events[1].payload["source"]) == source
|
||||
key = bus.events[1].payload["key"]
|
||||
assert isinstance(bus.events[0], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallCompleteEvent)
|
||||
assert Path(bus.events[0].source) == source
|
||||
assert Path(bus.events[1].source) == source
|
||||
key = bus.events[1].key
|
||||
assert key is not None
|
||||
|
||||
# see if the thing actually got installed at the expected location
|
||||
@ -226,7 +234,7 @@ def test_delete_register(
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert store is not None
|
||||
assert bus is not None
|
||||
@ -244,20 +252,17 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
|
||||
assert len(bus.events) == 4
|
||||
event_names = [x.event_name for x in bus.events]
|
||||
assert event_names == [
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
]
|
||||
assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent)
|
||||
assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent)
|
||||
assert isinstance(bus.events[2], ModelInstallStartedEvent)
|
||||
assert isinstance(bus.events[3], ModelInstallCompleteEvent)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
bus: TestEventService = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
@ -274,15 +279,10 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert any(isinstance(x, ModelInstallStartedEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallDownloadProgressEvent) for x in bus.events)
|
||||
assert any(isinstance(x, ModelInstallCompleteEvent) for x in bus.events)
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
@ -308,19 +308,24 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
event_types = [type(x) for x in bus.events]
|
||||
assert all(
|
||||
x in event_types
|
||||
for x in [
|
||||
ModelInstallDownloadProgressEvent,
|
||||
ModelInstallDownloadsCompleteEvent,
|
||||
ModelInstallStartedEvent,
|
||||
ModelInstallCompleteEvent,
|
||||
]
|
||||
)
|
||||
|
||||
completed_events = [x for x in bus.events if x.event_name == "model_install_completed"]
|
||||
downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"]
|
||||
assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"]
|
||||
assert job.total_bytes == completed_events[0].payload["total_bytes"]
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
|
||||
completed_events = [x for x in bus.events if isinstance(x, ModelInstallCompleteEvent)]
|
||||
downloading_events = [x for x in bus.events if isinstance(x, ModelInstallDownloadProgressEvent)]
|
||||
assert completed_events[0].total_bytes == downloading_events[-1].bytes
|
||||
assert job.total_bytes == completed_events[0].total_bytes
|
||||
print(downloading_events[-1])
|
||||
print(job.download_parts)
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].parts)
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
|
Reference in New Issue
Block a user