mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
final tidying before marking PR as ready for review
- Replace AnyModelLoader with ModelLoaderRegistry - Fix type check errors in multiple files - Remove apparently unneeded `get_model_config_enum()` method from model manager - Remove last vestiges of old model manager - Updated tests and documentation resolve conflict with seamless.py
This commit is contained in:
@ -4,18 +4,27 @@ Test model loading
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
|
||||
store = mm2_installer.record_store
|
||||
|
||||
def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
|
||||
store = mm2_model_manager.store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
|
||||
key = mm2_model_manager.install.register_path(embedding_file)
|
||||
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
||||
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
||||
model_name=loaded_model.config.name,
|
||||
model_type=loaded_model.config.type,
|
||||
base_model=loaded_model.config.base,
|
||||
)
|
||||
assert loaded_model.config.key == loaded_model_3.config.key
|
||||
|
@ -6,17 +6,17 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pytest import FixtureRequest
|
||||
from pydantic import BaseModel
|
||||
from pytest import FixtureRequest
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
|
||||
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_download_queue(mm2_session: Session,
|
||||
request: FixtureRequest
|
||||
) -> DownloadQueueServiceBase:
|
||||
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
@ -107,30 +105,34 @@ def mm2_download_queue(mm2_session: Session,
|
||||
request.addfinalizer(stop_queue)
|
||||
return download_queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
max_cache_size=mm2_app_config.ram_cache_size,
|
||||
max_vram_cache_size=mm2_app_config.vram_cache_size
|
||||
max_vram_cache_size=mm2_app_config.vram_cache_size,
|
||||
)
|
||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||
return ModelLoadService(app_config=mm2_app_config,
|
||||
record_store=mm2_record_store,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
return ModelLoadService(
|
||||
app_config=mm2_app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
def mm2_installer(
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
|
||||
store.add_model("test_config_5", raw5)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
|
||||
mm2_installer: ModelInstallServiceBase,
|
||||
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
|
||||
return ModelManagerService(
|
||||
store=mm2_record_store,
|
||||
install=mm2_installer,
|
||||
load=mm2_loader
|
||||
)
|
||||
def mm2_model_manager(
|
||||
mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
|
||||
) -> ModelManagerServiceBase:
|
||||
return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
|
@ -5,8 +5,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -1,7 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||
|
||||
|
||||
def test_memory_snapshot_capture():
|
||||
"""Smoke test of MemorySnapshot.capture()."""
|
||||
|
Reference in New Issue
Block a user