mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to original model manager. - Caught and fixed most dangling references (still checking) - Rename lora, textual_inversion and model_patcher modules - Introduce a RawModel base class to simplfy the Union returned by the model loaders. - Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
This commit is contained in:
committed by
psychedelicious
parent
ba1f8878dd
commit
2ad0752582
@ -2,7 +2,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
|
||||
|
@ -5,17 +5,16 @@ Test model loading
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.model_install import ModelInstallServiceBase
|
||||
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from invokeai.app.services.model_load import ModelLoadServiceBase
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: AnyModelLoader, embedding_file: Path):
|
||||
def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
|
||||
store = mm2_installer.record_store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_installer.register_path(embedding_file)
|
||||
loaded_model = mm2_loader.load_model(store.get_model(key))
|
||||
loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
@ -6,24 +6,27 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pytest import FixtureRequest
|
||||
from pydantic import BaseModel
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
|
||||
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.load import AnyModelLoader, ModelCache, ModelConvertCache
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
@ -86,22 +89,71 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(
|
||||
root=mm2_root_dir,
|
||||
models_dir=mm2_root_dir / "models",
|
||||
log_level="info",
|
||||
)
|
||||
return app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceSQL) -> AnyModelLoader:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
def mm2_download_queue(mm2_session: Session,
|
||||
request: FixtureRequest
|
||||
) -> DownloadQueueServiceBase:
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
def stop_queue() -> None:
|
||||
download_queue.stop()
|
||||
|
||||
request.addfinalizer(stop_queue)
|
||||
return download_queue
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=logger, max_cache_size=mm2_app_config.ram_cache_size, max_vram_cache_size=mm2_app_config.vram_cache_size
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
max_cache_size=mm2_app_config.ram_cache_size,
|
||||
max_vram_cache_size=mm2_app_config.vram_cache_size
|
||||
)
|
||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||
return AnyModelLoader(app_config=mm2_app_config, logger=logger, ram_cache=ram_cache, convert_cache=convert_cache)
|
||||
return ModelLoadService(app_config=mm2_app_config,
|
||||
record_store=mm2_record_store,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=mm2_download_queue,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
|
||||
def stop_installer() -> None:
|
||||
installer.stop()
|
||||
|
||||
request.addfinalizer(stop_installer)
|
||||
return installer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
@ -161,11 +213,15 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
||||
store.add_model("test_config_5", raw5)
|
||||
return store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
def mm2_model_manager(mm2_record_store: ModelRecordServiceBase,
|
||||
mm2_installer: ModelInstallServiceBase,
|
||||
mm2_loader: ModelLoadServiceBase) -> ModelManagerServiceBase:
|
||||
return ModelManagerService(
|
||||
store=mm2_record_store,
|
||||
install=mm2_installer,
|
||||
load=mm2_loader
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
@ -252,22 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
@ -19,7 +19,7 @@ from invokeai.backend.model_manager.metadata import (
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||
from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2
|
||||
|
||||
|
||||
def test_libc_util_mallinfo2():
|
@ -5,8 +5,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.lora import ModelPatcher
|
||||
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
@ -1,8 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
|
||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
|
||||
def test_memory_snapshot_capture():
|
||||
"""Smoke test of MemorySnapshot.capture()."""
|
||||
@ -26,6 +25,7 @@ snapshots = [
|
||||
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
|
||||
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
|
||||
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
|
||||
print(msg)
|
||||
|
||||
expected_lines = 0
|
||||
if snapshot_1 is not None and snapshot_2 is not None:
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
|
||||
from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
Reference in New Issue
Block a user