final tidying before marking PR as ready for review

- Replace AnyModelLoader with ModelLoaderRegistry (see the sketch below the commit details)
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Update tests and documentation

resolve conflict with seamless.py
Author: psychedelicious
Date: 2024-02-18 17:27:42 +11:00
Parent: 2ad0752582
Commit: be8b99eed5
74 changed files with 672 additions and 10362 deletions
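
For readers skimming the diffs: the sketch below illustrates the registry pattern the first bullet refers to. All class, method, and parameter names here are assumptions for illustration; the actual ModelLoaderRegistry API lives in the InvokeAI backend and may differ.

    # Illustrative sketch only -- names and signatures are assumptions,
    # not InvokeAI's actual API.
    from typing import Callable, Dict, Tuple, Type


    class ModelLoaderRegistry:
        """Map (base, type, format) keys to loader classes, replacing a
        single do-everything AnyModelLoader with per-format registrations."""

        _registry: Dict[Tuple[str, str, str], Type] = {}

        @classmethod
        def register(cls, base: str, model_type: str, model_format: str) -> Callable[[Type], Type]:
            def decorator(loader_cls: Type) -> Type:
                cls._registry[(base, model_type, model_format)] = loader_cls
                return loader_cls

            return decorator

        @classmethod
        def get_implementation(cls, base: str, model_type: str, model_format: str) -> Type:
            return cls._registry[(base, model_type, model_format)]


    @ModelLoaderRegistry.register("sd-1", "embedding", "embedding_file")
    class TextualInversionLoader:
        def load(self, path: str):
            ...  # a real loader would return the loaded model object

The payoff of this shape is that supporting a new model format means registering one new loader class rather than growing another branch inside a monolithic loader.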


@@ -4,18 +4,27 @@ Test model loading
 from pathlib import Path

-from invokeai.app.services.model_install import ModelInstallServiceBase
-from invokeai.app.services.model_load import ModelLoadServiceBase
+from invokeai.app.services.model_manager import ModelManagerServiceBase
 from invokeai.backend.textual_inversion import TextualInversionModelRaw
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


-def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
-    store = mm2_installer.record_store
+def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
+    store = mm2_model_manager.store
     matches = store.search_by_attr(model_name="test_embedding")
     assert len(matches) == 0
-    key = mm2_installer.register_path(embedding_file)
-    loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
+    key = mm2_model_manager.install.register_path(embedding_file)
+    loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
     assert loaded_model is not None
     assert loaded_model.config.key == key
     with loaded_model as model:
         assert isinstance(model, TextualInversionModelRaw)
+    loaded_model_2 = mm2_model_manager.load_model_by_key(key)
+    assert loaded_model.config.key == loaded_model_2.config.key
+    loaded_model_3 = mm2_model_manager.load_model_by_attr(
+        model_name=loaded_model.config.name,
+        model_type=loaded_model.config.type,
+        base_model=loaded_model.config.base,
+    )
+    assert loaded_model.config.key == loaded_model_3.config.key
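
The updated test doubles as documentation for the consolidated ModelManagerService facade: records (mm.store), installation (mm.install), and loading (load_model_by_key / load_model_by_attr / load_model_by_config) all hang off one object. A condensed usage sketch, restricted to the calls the test itself makes (the helper function and its name are hypothetical):

    from pathlib import Path


    def register_and_load(mm, embedding_path: Path):
        """Hypothetical helper; every method call below appears in the test above."""
        key = mm.install.register_path(embedding_path)  # register a local file
        loaded = mm.load_model_by_key(key)              # load by record key...
        again = mm.load_model_by_attr(                  # ...or by name/type/base
            model_name=loaded.config.name,
            model_type=loaded.config.type,
            base_model=loaded.config.base,
        )
        assert loaded.config.key == again.config.key    # both resolve to the same record
        with loaded as model:                           # context manager yields the raw model
            return type(model)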


@@ -6,17 +6,17 @@ from pathlib import Path
 from typing import Any, Dict, List

 import pytest
-from pytest import FixtureRequest
 from pydantic import BaseModel
+from pytest import FixtureRequest
 from requests.sessions import Session
 from requests_testadapter import TestAdapter, TestSession

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
+from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
-from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
-from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
 from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
+from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
+from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
 from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
 from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
 from invokeai.backend.model_manager.config import (
@@ -95,9 +95,7 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
 @pytest.fixture
-def mm2_download_queue(mm2_session: Session,
-                       request: FixtureRequest
-                       ) -> DownloadQueueServiceBase:
+def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
     download_queue = DownloadQueueService(requests_session=mm2_session)
     download_queue.start()
@@ -107,30 +105,34 @@ def mm2_download_queue(mm2_session: Session,
     request.addfinalizer(stop_queue)
     return download_queue


+@pytest.fixture
+def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
+    return mm2_record_store.metadata_store
+
+
 @pytest.fixture
 def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
     ram_cache = ModelCache(
         logger=InvokeAILogger.get_logger(),
         max_cache_size=mm2_app_config.ram_cache_size,
-        max_vram_cache_size=mm2_app_config.vram_cache_size
+        max_vram_cache_size=mm2_app_config.vram_cache_size,
     )
     convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
-    return ModelLoadService(app_config=mm2_app_config,
-                            record_store=mm2_record_store,
-                            ram_cache=ram_cache,
-                            convert_cache=convert_cache,
-                            )
+    return ModelLoadService(
+        app_config=mm2_app_config,
+        ram_cache=ram_cache,
+        convert_cache=convert_cache,
+    )


 @pytest.fixture
-def mm2_installer(mm2_app_config: InvokeAIAppConfig,
-                  mm2_download_queue: DownloadQueueServiceBase,
-                  mm2_session: Session,
-                  request: FixtureRequest,
-                  ) -> ModelInstallServiceBase:
+def mm2_installer(
+    mm2_app_config: InvokeAIAppConfig,
+    mm2_download_queue: DownloadQueueServiceBase,
+    mm2_session: Session,
+    request: FixtureRequest,
+) -> ModelInstallServiceBase:
     logger = InvokeAILogger.get_logger()
     db = create_mock_sqlite_database(mm2_app_config, logger)
     events = DummyEventService()
@@ -213,15 +215,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
     store.add_model("test_config_5", raw5)
     return store


 @pytest.fixture
-def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
-                      mm2_installer: ModelInstallServiceBase,
-                      mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
-    return ModelManagerService(
-        store=mm2_record_store,
-        install=mm2_installer,
-        load=mm2_loader
-    )
+def mm2_model_manager(
+    mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
+) -> ModelManagerServiceBase:
+    return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)


 @pytest.fixture
 def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
@@ -306,5 +306,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
         ),
     )
     return sess
-
-


@@ -5,8 +5,8 @@
 import pytest
 import torch

-from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher


 @pytest.mark.parametrize(


@@ -1,7 +1,8 @@
 import pytest

-from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
 from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2


 def test_memory_snapshot_capture():
     """Smoke test of MemorySnapshot.capture()."""