Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-conditioning

Ryan Dick
2024-03-01 11:03:04 -05:00
740 changed files with 24428 additions and 31726 deletions


@@ -1,7 +1,7 @@
 import pytest
 import torch

-from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
+from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
 from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
 from invokeai.backend.util.test_utils import install_and_load_model


@@ -0,0 +1,30 @@
+"""
+Test model loading
+"""
+
+from pathlib import Path
+
+from invokeai.app.services.model_manager import ModelManagerServiceBase
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
+
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
+
+
+def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
+    store = mm2_model_manager.store
+    matches = store.search_by_attr(model_name="test_embedding")
+    assert len(matches) == 0
+    key = mm2_model_manager.install.register_path(embedding_file)
+    loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
+    assert loaded_model is not None
+    assert loaded_model.config.key == key
+    with loaded_model as model:
+        assert isinstance(model, TextualInversionModelRaw)
+    loaded_model_2 = mm2_model_manager.load_model_by_key(key)
+    assert loaded_model.config.key == loaded_model_2.config.key
+    loaded_model_3 = mm2_model_manager.load_model_by_attr(
+        model_name=loaded_model.config.name,
+        model_type=loaded_model.config.type,
+        base_model=loaded_model.config.base,
+    )
+    assert loaded_model.config.key == loaded_model_3.config.key

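The new test above exercises three lookup paths (by config, by key, and by attributes) that must all resolve to the same registered model. A minimal sketch of the underlying flow, assuming a started ModelManagerServiceBase named `mm` and a local embedding file (the helper name register_and_load is illustrative, not part of the API):

from pathlib import Path

def register_and_load(mm, embedding_path: Path):
    # Register the file in place, then load it via its record-store config.
    key = mm.install.register_path(embedding_path)
    loaded = mm.load_model_by_config(mm.store.get_model(key))
    # The returned wrapper is a context manager that yields the raw model.
    with loaded as model:
        return model
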

@@ -2,27 +2,32 @@
 import os
 import shutil
+import time
 from pathlib import Path
 from typing import Any, Dict, List

 import pytest
 from pydantic import BaseModel
+from pytest import FixtureRequest
 from requests.sessions import Session
 from requests_testadapter import TestAdapter, TestSession

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.app.services.download import DownloadQueueService
+from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
-from invokeai.app.services.model_records import ModelRecordServiceSQL
+from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
+from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
+from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
+from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
 from invokeai.backend.model_manager.config import (
     BaseModelType,
     ModelFormat,
     ModelType,
 )
-from invokeai.backend.model_manager.metadata import ModelMetadataStore
+from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
 from invokeai.backend.util.logging import InvokeAILogger
-from tests.backend.model_manager_2.model_metadata.metadata_examples import (
+from tests.backend.model_manager.model_metadata.metadata_examples import (
     RepoCivitaiModelMetadata1,
     RepoCivitaiVersionMetadata1,
     RepoHFMetadata1,
@@ -85,15 +90,77 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
     app_config = InvokeAIAppConfig(
         root=mm2_root_dir,
         models_dir=mm2_root_dir / "models",
+        log_level="info",
     )
     return app_config


 @pytest.fixture
-def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
+def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
+    download_queue = DownloadQueueService(requests_session=mm2_session)
+    download_queue.start()
+
+    def stop_queue() -> None:
+        download_queue.stop()
+
+    request.addfinalizer(stop_queue)
+    return download_queue
+
+
+@pytest.fixture
+def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
+    return mm2_record_store.metadata_store
+
+
+@pytest.fixture
+def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
+    ram_cache = ModelCache(
+        logger=InvokeAILogger.get_logger(),
+        max_cache_size=mm2_app_config.ram_cache_size,
+        max_vram_cache_size=mm2_app_config.vram_cache_size,
+    )
+    convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
+    return ModelLoadService(
+        app_config=mm2_app_config,
+        ram_cache=ram_cache,
+        convert_cache=convert_cache,
+    )
+
+
+@pytest.fixture
+def mm2_installer(
+    mm2_app_config: InvokeAIAppConfig,
+    mm2_download_queue: DownloadQueueServiceBase,
+    mm2_session: Session,
+    request: FixtureRequest,
+) -> ModelInstallServiceBase:
+    logger = InvokeAILogger.get_logger()
+    db = create_mock_sqlite_database(mm2_app_config, logger)
+    events = DummyEventService()
+    store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
+
+    installer = ModelInstallService(
+        app_config=mm2_app_config,
+        record_store=store,
+        download_queue=mm2_download_queue,
+        event_bus=events,
+        session=mm2_session,
+    )
+    installer.start()
+
+    def stop_installer() -> None:
+        installer.stop()
+        time.sleep(0.1)  # avoid error message from the logger when it is closed before thread prints final message
+
+    request.addfinalizer(stop_installer)
+    return installer
+
+
+@pytest.fixture
+def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
     logger = InvokeAILogger.get_logger(config=mm2_app_config)
     db = create_mock_sqlite_database(mm2_app_config, logger)
-    store = ModelRecordServiceSQL(db)
+    store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
     # add five simple config records to the database
     raw1 = {
         "path": "/tmp/foo1",
@@ -152,15 +219,16 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
 @pytest.fixture
-def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
-    db = mm2_record_store._db  # to ensure we are sharing the same database
-    return ModelMetadataStore(db)
+def mm2_model_manager(
+    mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
+) -> ModelManagerServiceBase:
+    return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)


 @pytest.fixture
 def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
     """This fixture defines a series of mock URLs for testing download and installation."""
-    sess = TestSession()
+    sess: Session = TestSession()
     sess.mount(
         "https://test.com/missing_model.safetensors",
         TestAdapter(
@@ -240,26 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
         ),
     )
     return sess
-
-
-@pytest.fixture
-def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
-    logger = InvokeAILogger.get_logger()
-    db = create_mock_sqlite_database(mm2_app_config, logger)
-    events = DummyEventService()
-    store = ModelRecordServiceSQL(db)
-    metadata_store = ModelMetadataStore(db)
-    download_queue = DownloadQueueService(requests_session=mm2_session)
-    download_queue.start()
-    installer = ModelInstallService(
-        app_config=mm2_app_config,
-        record_store=store,
-        download_queue=download_queue,
-        metadata_store=metadata_store,
-        event_bus=events,
-        session=mm2_session,
-    )
-    installer.start()
-    return installer

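The rewritten fixtures above share one start/finalize shape: build the service, start it, and register a finalizer so it is stopped after the dependent test finishes. A generic sketch of that pattern (DummyService is a stand-in, not an InvokeAI class):

import pytest
from pytest import FixtureRequest

class DummyService:
    """Stand-in for a service such as DownloadQueueService or ModelInstallService."""

    def start(self) -> None: ...

    def stop(self) -> None: ...

@pytest.fixture
def running_service(request: FixtureRequest) -> DummyService:
    service = DummyService()
    service.start()
    # Finalizers run during teardown, mirroring stop_queue/stop_installer above.
    request.addfinalizer(service.stop)
    return service
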

@@ -1,6 +1,7 @@
 """
 Test model metadata fetching and storage.
 """
+
 import datetime
 from pathlib import Path
@@ -8,6 +9,7 @@ import pytest
 from pydantic.networks import HttpUrl
 from requests.sessions import Session

+from invokeai.app.services.model_metadata import ModelMetadataStoreBase
 from invokeai.backend.model_manager.config import ModelRepoVariant
 from invokeai.backend.model_manager.metadata import (
     CivitaiMetadata,
@@ -15,14 +17,13 @@ from invokeai.backend.model_manager.metadata import (
     CommercialUsage,
     HuggingFaceMetadata,
     HuggingFaceMetadataFetch,
-    ModelMetadataStore,
     UnknownMetadataException,
 )
 from invokeai.backend.model_manager.util import select_hf_files
-from tests.backend.model_manager_2.model_manager_2_fixtures import *  # noqa F403
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


-def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     tags = {"text-to-image", "diffusers"}
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
@@ -40,7 +41,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
     assert mm2_metadata_store.list_tags() == tags


-def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     input_metadata = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",
@@ -57,7 +58,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
     assert input_metadata == output_metadata


-def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
+def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
     metadata1 = HuggingFaceMetadata(
         name="sdxl-vae",
         author="stabilityai",
@@ -133,7 +134,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
     assert metadata.id == 215485
     assert metadata.author == "test_author"  # note that this is not the same as the original from Civitai
     assert metadata.allow_commercial_use  # changed to make sure we are reading locally not remotely
-    assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
+    assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
     assert metadata.version_id == 242807
     assert metadata.tags == {"tool", "turbo", "sdxl turbo"}


@@ -1,6 +1,6 @@
 import pytest

-from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
+from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2


 def test_libc_util_mallinfo2():


@@ -5,8 +5,8 @@
 import pytest
 import torch

-from invokeai.backend.model_management.lora import ModelPatcher
-from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.lora import LoRALayer, LoRAModelRaw
+from invokeai.backend.model_patcher import ModelPatcher


 @pytest.mark.parametrize(


@@ -1,7 +1,7 @@
 import pytest

-from invokeai.backend.model_management.libc_util import Struct_mallinfo2
-from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
+from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2


 def test_memory_snapshot_capture():
@@ -26,6 +26,7 @@ snapshots = [
 def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
     """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
     msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
+    print(msg)

     expected_lines = 0
     if snapshot_1 is not None and snapshot_2 is not None:

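For orientation, the two utilities under test compose as follows. A rough sketch, assuming MemorySnapshot exposes a capture() constructor as test_memory_snapshot_capture implies (which fields are populated varies by platform):

from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff

before = MemorySnapshot.capture()
data = [0] * 1_000_000  # allocate something measurable
after = MemorySnapshot.capture()
# Emits one line per field that is present in both snapshots.
print(get_pretty_snapshot_diff(before, after))
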

@@ -1,7 +1,7 @@
 import pytest
 import torch

-from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
+from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init


 @pytest.mark.parametrize(


@@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]:
             "text_encoder/model.onnx",
             "text_encoder_2/config.json",
             "text_encoder_2/model.onnx",
+            "text_encoder_2/model.onnx_data",
             "tokenizer/merges.txt",
             "tokenizer/special_tokens_map.json",
             "tokenizer/tokenizer_config.json",
@@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]:
             "tokenizer_2/vocab.json",
             "unet/config.json",
             "unet/model.onnx",
+            "unet/model.onnx_data",
             "vae_decoder/config.json",
             "vae_decoder/model.onnx",
             "vae_encoder/config.json",


@@ -1,6 +1,7 @@
 """
 Test interaction of logging with configuration system.
 """
+
 import io
 import logging
 import re