Tidy names and locations of modules

- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplfy the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
This commit is contained in:
Lincoln Stein
2024-02-17 11:45:32 -05:00
committed by psychedelicious
parent ba1f8878dd
commit 2ad0752582
89 changed files with 355 additions and 1609 deletions

View File

@ -2,7 +2,7 @@ import pytest
import torch
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
from invokeai.backend.util.test_utils import install_and_load_model

View File

@ -5,17 +5,16 @@ Test model loading
from pathlib import Path
from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager.load import AnyModelLoader
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from invokeai.app.services.model_load import ModelLoadServiceBase
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
store = mm2_installer.record_store
matches = store.search_by_attr(model_name="test_embedding")
assert len(matches) == 0
key = mm2_installer.register_path(embedding_file)
loaded_model = mm2_loader.load_model(store.get_model(key))
loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
assert loaded_model is not None
assert loaded_model.config.key == key
with loaded_model as model:

View File

@ -6,24 +6,27 @@ from pathlib import Path
from typing import Any, Dict, List
import pytest
from pytest import FixtureRequest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
@ -86,22 +89,71 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
app_config = InvokeAIAppConfig(
root=mm2_root_dir,
models_dir=mm2_root_dir / "models",
log_level="info",
)
return app_config
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
def mm2_download_queue(mm2_session: Session,
request: FixtureRequest
) -> DownloadQueueServiceBase:
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
def stop_queue() -> None:
download_queue.stop()
request.addfinalizer(stop_queue)
return download_queue
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
ram_cache = ModelCache(
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
logger=InvokeAILogger.get_logger(),
max_cache_size=mm2_app_config.ram_cache_size,
max_vram_cache_size=mm2_app_config.vram_cache_size
)
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
return ModelLoadService(app_config=mm2_app_config,
record_store=mm2_record_store,
ram_cache=ram_cache,
convert_cache=convert_cache,
)
@pytest.fixture
def mm2_installer(mm2_app_config: InvokeAIAppConfig,
mm2_download_queue: DownloadQueueServiceBase,
mm2_session: Session,
request: FixtureRequest,
) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=store,
download_queue=mm2_download_queue,
event_bus=events,
session=mm2_session,
)
installer.start()
def stop_installer() -> None:
installer.stop()
request.addfinalizer(stop_installer)
return installer
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
logger = InvokeAILogger.get_logger(config=mm2_app_config)
db = create_mock_sqlite_database(mm2_app_config, logger)
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
@ -161,11 +213,15 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
store.add_model("test_config_5", raw5)
return store
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
return mm2_record_store.metadata_store
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
mm2_installer: ModelInstallServiceBase,
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
return ModelManagerService(
store=mm2_record_store,
install=mm2_installer,
load=mm2_loader
)
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
@ -252,22 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
return sess
@pytest.fixture
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
logger = InvokeAILogger.get_logger()
db = create_mock_sqlite_database(mm2_app_config, logger)
events = DummyEventService()
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
download_queue = DownloadQueueService(requests_session=mm2_session)
download_queue.start()
installer = ModelInstallService(
app_config=mm2_app_config,
record_store=store,
download_queue=download_queue,
event_bus=events,
session=mm2_session,
)
installer.start()
return installer

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_manager.metadata import (
UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:

View File

@ -1,6 +1,6 @@
import pytest
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2
def test_libc_util_mallinfo2():

View File

@ -5,8 +5,8 @@
import pytest
import torch
from invokeai.backend.model_management.lora import ModelPatcher
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
@pytest.mark.parametrize(

View File

@ -1,8 +1,7 @@
import pytest
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
def test_memory_snapshot_capture():
"""Smoke test of MemorySnapshot.capture()."""
@ -26,6 +25,7 @@ snapshots = [
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
print(msg)
expected_lines = 0
if snapshot_1 is not None and snapshot_2 is not None:

View File

@ -1,7 +1,7 @@
import pytest
import torch
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init
@pytest.mark.parametrize(