InvokeAI/tests/backend/model_manager/model_manager_fixtures.py
Lincoln Stein ab46865e5b Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch
  dangling references to original model manager.
- Caught and fixed most dangling references (still checking)
- Rename lora, textual_inversion and model_patcher modules
- Introduce a RawModel base class to simplify the Union returned by the
  model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and
  a finalizer to the queue and installer fixtures that will stop the
  services and release threads.
2024-02-29 13:28:20 -05:00

311 lines
10 KiB
Python

# Fixtures to support testing of the model_manager v2 installer, metadata and record store
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import pytest
from pytest import FixtureRequest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
RepoCivitaiModelMetadata1,
RepoCivitaiVersionMetadata1,
RepoHFMetadata1,
RepoHFMetadata1_nofp16,
RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
class DummyEvent(BaseModel):
    """Dummy event record, for use with the dummy event service."""

    # Name under which the event was recorded.
    event_name: str
    # Data carried by the event, keyed by string.
    payload: Dict[str, Any]
class DummyEventService(EventServiceBase):
    """Event service for testing that keeps every dispatched event in a list."""

    # Events recorded by dispatch(), in the order they were received.
    events: List[DummyEvent]

    def __init__(self) -> None:
        super().__init__()
        self.events = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Record an event by appending it to self.events.

        The recorded name is read from ``payload["event"]`` and the recorded
        data from ``payload["data"]``; the ``event_name`` argument is not used.
        """
        recorded = DummyEvent(event_name=payload["event"], payload=payload["data"])
        self.events.append(recorded)
@pytest.fixture
def mm2_root_dir(tmp_path_factory) -> Path:
    """Create a temporary directory using the contents of `./data/invokeai_root` as the template.

    Returns the path of the copied `invokeai_root` directory.
    """
    data_dir = Path(__file__).resolve().parent / "data"
    destination: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
    shutil.copytree(data_dir / "invokeai_root", destination)
    return destination
@pytest.fixture
def mm2_model_files(tmp_path_factory) -> Path:
    """Copy the `./data/test_files` template into a fresh temporary directory.

    Returns the path of the copied `test_files` directory.
    """
    data_dir = Path(__file__).resolve().parent / "data"
    destination: Path = tmp_path_factory.mktemp("data") / "test_files"
    shutil.copytree(data_dir / "test_files", destination)
    return destination
@pytest.fixture
def embedding_file(mm2_model_files: Path) -> Path:
    """Return the path of `test_embedding.safetensors` inside the copied test files."""
    return mm2_model_files.joinpath("test_embedding.safetensors")
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
    """Return the path of `test-diffusers-main` inside the copied test files."""
    return mm2_model_files.joinpath("test-diffusers-main")
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
    """Build an app config whose root and models directory live under the temporary root."""
    return InvokeAIAppConfig(
        root=mm2_root_dir,
        models_dir=mm2_root_dir / "models",
        log_level="info",
    )
@pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
    """Start a download queue that uses the mock session, and stop it at fixture teardown."""
    queue = DownloadQueueService(requests_session=mm2_session)
    queue.start()
    # Registered as a finalizer so the started queue is stopped once the test is done.
    request.addfinalizer(queue.stop)
    return queue
@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
    """Return the metadata store held by the record store fixture."""
    metadata_store = mm2_record_store.metadata_store
    return metadata_store
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
    """Build a model load service from the test config and record store.

    The RAM cache limits and the convert-cache path are all read from
    ``mm2_app_config``.
    """
    config = mm2_app_config
    model_cache = ModelCache(
        logger=InvokeAILogger.get_logger(),
        max_cache_size=config.ram_cache_size,
        max_vram_cache_size=config.vram_cache_size,
    )
    conversion_cache = ModelConvertCache(config.models_convert_cache_path)
    return ModelLoadService(
        app_config=config,
        record_store=mm2_record_store,
        ram_cache=model_cache,
        convert_cache=conversion_cache,
    )
@pytest.fixture
def mm2_installer(
    mm2_app_config: InvokeAIAppConfig,
    mm2_download_queue: DownloadQueueServiceBase,
    mm2_session: Session,
    request: FixtureRequest,
) -> ModelInstallServiceBase:
    """Start a model install service and stop it at fixture teardown.

    The service gets its own mock SQLite database and record store (created
    here, not taken from another fixture) and a DummyEventService as event bus.
    """
    database = create_mock_sqlite_database(mm2_app_config, InvokeAILogger.get_logger())
    event_bus = DummyEventService()
    record_store = ModelRecordServiceSQL(database, ModelMetadataStoreSQL(database))
    install_service = ModelInstallService(
        app_config=mm2_app_config,
        record_store=record_store,
        download_queue=mm2_download_queue,
        event_bus=event_bus,
        session=mm2_session,
    )
    install_service.start()
    # Registered as a finalizer so the started service is stopped once the test is done.
    request.addfinalizer(install_service.stop)
    return install_service
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """Create a SQL record store on a mock SQLite database, pre-populated with five records."""
    logger = InvokeAILogger.get_logger(config=mm2_app_config)
    database = create_mock_sqlite_database(mm2_app_config, logger)
    record_store = ModelRecordServiceSQL(database, ModelMetadataStoreSQL(database))

    # five simple config records, keyed by the model key they are stored under
    sample_configs = {
        "test_config_1": {
            "path": "/tmp/foo1",
            "format": ModelFormat("diffusers"),
            "name": "test2",
            "base": BaseModelType("sd-2"),
            "type": ModelType("vae"),
            "original_hash": "111222333444",
            "source": "stabilityai/sdxl-vae",
        },
        "test_config_2": {
            "path": "/tmp/foo2.ckpt",
            "name": "model1",
            "format": ModelFormat("checkpoint"),
            "base": BaseModelType("sd-1"),
            "type": "main",
            "config": "/tmp/foo.yaml",
            "variant": "normal",
            "original_hash": "111222333444",
            "source": "https://civitai.com/models/206883/split",
        },
        "test_config_3": {
            "path": "/tmp/foo3",
            "format": ModelFormat("diffusers"),
            "name": "test3",
            "base": BaseModelType("sdxl"),
            "type": ModelType("main"),
            "original_hash": "111222333444",
            "source": "author3/model3",
            "description": "This is test 3",
        },
        "test_config_4": {
            "path": "/tmp/foo4",
            "format": ModelFormat("diffusers"),
            "name": "test4",
            "base": BaseModelType("sdxl"),
            "type": ModelType("lora"),
            "original_hash": "111222333444",
            "source": "author4/model4",
        },
        "test_config_5": {
            "path": "/tmp/foo5",
            "format": ModelFormat("diffusers"),
            "name": "test5",
            "base": BaseModelType("sd-1"),
            "type": ModelType("lora"),
            "original_hash": "111222333444",
            "source": "author4/model5",
        },
    }
    # add the records to the database in the order listed above
    for model_key, model_config in sample_configs.items():
        record_store.add_model(model_key, model_config)
    return record_store
@pytest.fixture
def mm2_model_manager(
    mm2_record_store: ModelRecordServiceBase,
    mm2_installer: ModelInstallServiceBase,
    mm2_loader: ModelLoadServiceBase,
) -> ModelManagerServiceBase:
    """Assemble a model manager service from the record store, installer and loader fixtures."""
    manager = ModelManagerService(
        store=mm2_record_store,
        install=mm2_installer,
        load=mm2_loader,
    )
    return manager
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
    """Define a series of mock URLs for testing download and installation.

    Mounts canned responses on a TestSession for:

    - a URL that answers with status 404,
    - HuggingFace and Civitai metadata endpoints,
    - the test embedding file,
    - every file under ``diffusers_dir``, served beneath the
      ``stabilityai/sdxl-turbo`` ``resolve/main`` URL.

    Args:
        embedding_file: Path of the embedding file to serve.
        diffusers_dir: Directory whose files are served as the sdxl-turbo repo.

    Returns:
        The session with all mock adapters mounted.
    """
    json_content_type = "application/json; charset=utf-8"

    sess: Session = TestSession()
    sess.mount(
        "https://test.com/missing_model.safetensors",
        TestAdapter(
            b"missing",
            status=404,
        ),
    )

    # HuggingFace repo metadata, served with a JSON content type.
    # Each URL is mounted exactly once; mounting the same prefix again would
    # only replace the adapter with an identical one.
    hf_metadata = (
        ("https://huggingface.co/api/models/stabilityai/sdxl-turbo", RepoHFMetadata1),
        ("https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16", RepoHFMetadata1_nofp16),
    )
    for url, body in hf_metadata:
        sess.mount(
            url,
            TestAdapter(
                body,
                headers={"Content-Type": json_content_type, "Content-Length": len(body)},
            ),
        )

    # Responses served with only a Content-Length header.
    length_only = (
        ("https://civitai.com/api/v1/model-versions/242807", RepoCivitaiVersionMetadata1),
        ("https://civitai.com/api/v1/models/215485", RepoCivitaiModelMetadata1),
        ("https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json", RepoHFModelJson1),
    )
    for url, body in length_only:
        sess.mount(
            url,
            TestAdapter(
                body,
                headers={
                    "Content-Length": len(body),
                },
            ),
        )

    embedding_data = embedding_file.read_bytes()  # file is small - just 15K
    sess.mount(
        "https://www.test.foo/download/test_embedding.safetensors",
        TestAdapter(
            embedding_data,
            headers={"Content-Type": "application/octet-stream", "Content-Length": len(embedding_data)},
        ),
    )

    # Serve each file of the diffusers test model at the URL matching its
    # path relative to diffusers_dir.
    for root, _, files in os.walk(diffusers_dir):
        for name in files:
            path = Path(root, name)
            url_base = path.relative_to(diffusers_dir).as_posix()
            url = f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}"
            data = path.read_bytes()
            sess.mount(
                url,
                TestAdapter(
                    data,
                    headers={
                        "Content-Type": json_content_type,
                        "Content-Length": len(data),
                    },
                ),
            )
    return sess