InvokeAI/tests/backend/model_manager/model_manager_fixtures.py
Lincoln Stein 9f9379682e ruff fixes
2024-06-07 13:54:41 +10:00

344 lines
11 KiB
Python

# Fixtures to support testing of the model_manager v2 installer, metadata and record store
import os
import shutil
from collections.abc import Iterator
from pathlib import Path

import pytest
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
    BaseModelType,
    LoRADiffusersConfig,
    MainCheckpointConfig,
    MainDiffusersConfig,
    ModelFormat,
    ModelSourceType,
    ModelType,
    ModelVariantType,
    VAEDiffusersConfig,
)
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
    HFTestLoraMetadata,
    RepoCivitaiModelMetadata1,
    RepoCivitaiVersionMetadata1,
    RepoHFMetadata1,
    RepoHFMetadata1_nofp16,
    RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database
from tests.test_nodes import TestEventService
# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@pytest.fixture
def mm2_root_dir(tmp_path_factory) -> Path:
    """Return a fresh, writable copy of the ``data/invokeai_root`` template tree.

    The template lives next to this module; it is copied into a new pytest
    temporary directory so tests can modify the copy without touching it.
    """
    template = Path(__file__).resolve().parent.joinpath("data", "invokeai_root")
    copy_dest: Path = tmp_path_factory.mktemp("data").joinpath("invokeai_root")
    shutil.copytree(template, copy_dest)
    return copy_dest
@pytest.fixture
def mm2_model_files(tmp_path_factory) -> Path:
    """Return a fresh, writable copy of the ``data/test_files`` directory.

    The sample model files next to this module are copied into a new pytest
    temporary directory so tests never modify the originals.
    """
    template = Path(__file__).resolve().parent.joinpath("data", "test_files")
    copy_dest: Path = tmp_path_factory.mktemp("data").joinpath("test_files")
    shutil.copytree(template, copy_dest)
    return copy_dest
@pytest.fixture
def embedding_file(mm2_model_files: Path) -> Path:
    """Path of the sample ``test_embedding.safetensors`` file in the copied test files."""
    return mm2_model_files.joinpath("test_embedding.safetensors")
@pytest.fixture
def vae_directory(mm2_model_files: Path) -> Path:
    """Path of the ``taesdxl`` directory in the copied test files."""
    return mm2_model_files.joinpath("taesdxl")
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
    """Path of the ``test-diffusers-main`` directory in the copied test files."""
    return mm2_model_files.joinpath("test-diffusers-main")
@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
    """Return an app config whose models dir and root both point into the temp tree.

    Args:
        mm2_root_dir: Temporary copy of the ``invokeai_root`` template.
    """
    config = InvokeAIAppConfig(models_dir=mm2_root_dir.joinpath("models"), log_level="info")
    # ``_root`` is a private attribute; it is set directly, presumably so that
    # root-relative paths resolve inside the temporary tree — confirm if changing.
    config._root = mm2_root_dir
    return config
@pytest.fixture
def mm2_download_queue(mm2_session: Session) -> Iterator[DownloadQueueServiceBase]:
    """Yield a started download queue that performs its requests through ``mm2_session``.

    This is a yield fixture, so it is annotated as an iterator of the service
    rather than as the service itself.

    Args:
        mm2_session: Session whose URLs are answered by canned test responses.

    Yields:
        A running ``DownloadQueueService``; it is stopped during fixture teardown.
    """
    download_queue = DownloadQueueService(requests_session=mm2_session)
    download_queue.start()
    yield download_queue
    # Teardown: runs when pytest resumes the fixture after the test.
    download_queue.stop()
@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
    """Return a model load service backed by a RAM cache and a convert cache.

    Cache sizes come from the config's ``ram`` and ``vram`` settings, and the
    convert cache lives at ``convert_cache_path``.

    Args:
        mm2_app_config: Application config pointing into the temporary tree.
        mm2_record_store: Requested but not used in this body; it stays in the
            signature so the fixture dependency graph is unchanged.
    """
    cfg = mm2_app_config
    return ModelLoadService(
        app_config=cfg,
        ram_cache=ModelCache(
            logger=InvokeAILogger.get_logger(),
            max_cache_size=cfg.ram,
            max_vram_cache_size=cfg.vram,
        ),
        convert_cache=ModelConvertCache(cfg.convert_cache_path),
    )
@pytest.fixture
def mm2_installer(
    mm2_app_config: InvokeAIAppConfig,
    mm2_download_queue: DownloadQueueServiceBase,
    mm2_session: Session,
) -> Iterator[ModelInstallServiceBase]:
    """Yield a started model installer wired to the mock download queue and session.

    This is a yield fixture, so it is annotated as an iterator of the service
    rather than as the service itself.

    Args:
        mm2_app_config: Application config pointing into the temporary tree.
        mm2_download_queue: Running download queue used for remote installs.
        mm2_session: Session whose URLs are answered by canned test responses.

    Yields:
        A running ``ModelInstallService``; it is stopped during fixture teardown.
    """
    logger = InvokeAILogger.get_logger()
    # NOTE(review): a fresh database and record store are created here instead of
    # reusing the pre-seeded ``mm2_record_store`` fixture. Presumably this lets
    # installer tests start from an empty store; confirm before changing.
    db = create_mock_sqlite_database(mm2_app_config, logger)
    events = TestEventService()
    store = ModelRecordServiceSQL(db)
    installer = ModelInstallService(
        app_config=mm2_app_config,
        record_store=store,
        download_queue=mm2_download_queue,
        event_bus=events,
        session=mm2_session,
    )
    installer.start()
    yield installer
    # Teardown: runs when pytest resumes the fixture after the test.
    installer.stop()
@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    """Return a SQL-backed record store seeded with five model config records.

    The seed records are one diffusers VAE, one checkpoint main model, one
    diffusers main model and two diffusers LoRAs. They use SD1, SD2 and SDXL
    bases and all share the same hash value. Keys run ``test_config_1`` to
    ``test_config_5``. The first record is named "test2"; the name is kept
    unchanged because tests may look records up by name.
    """
    store = ModelRecordServiceSQL(
        create_mock_sqlite_database(mm2_app_config, InvokeAILogger.get_logger(config=mm2_app_config))
    )
    common_hash = "111222333444"
    seed_records = (
        VAEDiffusersConfig(
            key="test_config_1",
            path="/tmp/foo1",
            format=ModelFormat.Diffusers,
            name="test2",
            base=BaseModelType.StableDiffusion2,
            type=ModelType.VAE,
            hash=common_hash,
            source="stabilityai/sdxl-vae",
            source_type=ModelSourceType.HFRepoID,
        ),
        MainCheckpointConfig(
            key="test_config_2",
            path="/tmp/foo2.ckpt",
            name="model1",
            format=ModelFormat.Checkpoint,
            base=BaseModelType.StableDiffusion1,
            type=ModelType.Main,
            config_path="/tmp/foo.yaml",
            variant=ModelVariantType.Normal,
            hash=common_hash,
            source="https://civitai.com/models/206883/split",
            source_type=ModelSourceType.Url,
        ),
        MainDiffusersConfig(
            key="test_config_3",
            path="/tmp/foo3",
            format=ModelFormat.Diffusers,
            name="test3",
            base=BaseModelType.StableDiffusionXL,
            type=ModelType.Main,
            hash=common_hash,
            source="author3/model3",
            description="This is test 3",
            source_type=ModelSourceType.HFRepoID,
        ),
        LoRADiffusersConfig(
            key="test_config_4",
            path="/tmp/foo4",
            format=ModelFormat.Diffusers,
            name="test4",
            base=BaseModelType.StableDiffusionXL,
            type=ModelType.LoRA,
            hash=common_hash,
            source="author4/model4",
            source_type=ModelSourceType.HFRepoID,
        ),
        LoRADiffusersConfig(
            key="test_config_5",
            path="/tmp/foo5",
            format=ModelFormat.Diffusers,
            name="test5",
            base=BaseModelType.StableDiffusion1,
            type=ModelType.LoRA,
            hash=common_hash,
            source="author4/model5",
            source_type=ModelSourceType.HFRepoID,
        ),
    )
    # Insert in declaration order (test_config_1 first).
    for record in seed_records:
        store.add_model(record)
    return store
@pytest.fixture
def mm2_model_manager(
    mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
) -> ModelManagerServiceBase:
    """Return a model manager that combines the record store, installer and loader fixtures.

    NOTE(review): ``mm2_installer`` constructs its own record store, so ``store``
    here is a different object from the installer's record store. Confirm this is
    intended.
    """
    manager = ModelManagerService(
        store=mm2_record_store,
        install=mm2_installer,
        load=mm2_loader,
    )
    return manager
@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
    """Return a ``TestSession`` that answers a set of mock URLs from memory.

    The mocks let the download-queue and installer tests run without network
    access. They cover:

    - HuggingFace and Civitai metadata endpoints;
    - an embedding file;
    - every file under ``diffusers_dir``, exposed as a file of the
      ``stabilityai/sdxl-turbo`` repo;
    - several fake ``.safetensors`` downloads;
    - error cases (404 responses and a response without ``Content-Length``).

    Args:
        embedding_file: Local file whose bytes back the embedding download URLs.
        diffusers_dir: Local directory whose files back the sdxl-turbo repo file URLs.

    Returns:
        The configured session.
    """
    sess: Session = TestSession()
    json_type = "application/json; charset=utf-8"

    # --- metadata endpoints ---------------------------------------------------
    sess.mount(
        "https://test.com/missing_model.safetensors",
        TestAdapter(b"missing", status=404),
    )
    sess.mount(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo",
        TestAdapter(
            RepoHFMetadata1,
            headers={"Content-Type": json_type, "Content-Length": len(RepoHFMetadata1)},
        ),
    )
    # requests routes to the longest matching mount prefix, so this URL is not
    # shadowed by the shorter "sdxl-turbo" prefix mounted above.
    sess.mount(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
        TestAdapter(
            RepoHFMetadata1_nofp16,
            headers={"Content-Type": json_type, "Content-Length": len(RepoHFMetadata1_nofp16)},
        ),
    )
    sess.mount(
        "https://civitai.com/api/v1/model-versions/242807",
        TestAdapter(
            RepoCivitaiVersionMetadata1,
            headers={"Content-Length": len(RepoCivitaiVersionMetadata1)},
        ),
    )
    sess.mount(
        "https://civitai.com/api/v1/models/215485",
        TestAdapter(
            RepoCivitaiModelMetadata1,
            headers={"Content-Length": len(RepoCivitaiModelMetadata1)},
        ),
    )
    # NOTE(review): if diffusers_dir has a top-level model_index.json, the loop
    # below mounts this same URL again and replaces this adapter. Confirm which
    # payload the tests expect.
    sess.mount(
        "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
        TestAdapter(
            RepoHFModelJson1,
            headers={"Content-Length": len(RepoHFModelJson1)},
        ),
    )

    # --- embedding downloads --------------------------------------------------
    embedding_bytes = embedding_file.read_bytes()  # small file; reading it whole is fine
    sess.mount(
        "https://www.test.foo/download/test_embedding.safetensors",
        TestAdapter(
            embedding_bytes,
            headers={"Content-Type": "application/octet-stream", "Content-Length": len(embedding_bytes)},
        ),
    )
    sess.mount(
        "https://huggingface.co/api/models/InvokeAI-test/textual_inversion_tests?blobs=True",
        TestAdapter(
            HFTestLoraMetadata,
            headers={"Content-Type": json_type, "Content-Length": len(HFTestLoraMetadata)},
        ),
    )
    # NOTE(review): these binary bytes are served with a JSON content type (kept as-is).
    sess.mount(
        "https://huggingface.co/InvokeAI-test/textual_inversion_tests/resolve/main/learned_embeds-steps-1000.safetensors",
        TestAdapter(
            embedding_bytes,
            headers={"Content-Type": json_type, "Content-Length": len(embedding_bytes)},
        ),
    )

    # --- every local diffusers file, served as an sdxl-turbo repo file --------
    for root, _, files in os.walk(diffusers_dir):
        for name in files:
            path = Path(root, name)
            url_base = path.relative_to(diffusers_dir).as_posix()
            file_bytes = path.read_bytes()
            sess.mount(
                f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}",
                TestAdapter(
                    file_bytes,
                    headers={
                        "Content-Type": json_type,
                        "Content-Length": len(file_bytes),
                    },
                ),
            )

    # --- fake safetensors downloads -------------------------------------------
    def mock_safetensors(model_id: str) -> bytes:
        # Padded with 32,000 zero bytes: the pause tests need a large body.
        return b"I am a safetensors file " + model_id.encode("utf-8") + bytes(32_000)

    for model_id in ["12345", "9999", "54321"]:
        content = mock_safetensors(model_id)
        sess.mount(
            f"http://www.civitai.com/models/{model_id}",
            TestAdapter(
                content,
                headers={
                    "Content-Length": len(content),
                    "Content-Disposition": f'filename="mock{model_id}.safetensors"',
                },
            ),
        )
    # Uses the same body as the "54321" mock, built explicitly here.
    foo_content = mock_safetensors("54321")
    sess.mount(
        "http://www.huggingface.co/foo.txt",
        TestAdapter(
            foo_content,
            headers={
                "Content-Length": len(foo_content),
                "Content-Disposition": 'filename="foo.safetensors"',
            },
        ),
    )

    # --- malformed / failing responses ----------------------------------------
    # Response without a Content-Length header.
    sess.mount(
        "http://www.civitai.com/models/missing",
        TestAdapter(
            b"Missing content length",
            headers={
                "Content-Disposition": 'filename="missing.txt"',
            },
        ),
    )
    # Not-found response.
    sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
    return sess