mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'ryan/remove-attention-map-saving' into ryan/regional-conditioning
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.models.base import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager import BaseModelType, ModelType, SubModelType
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher
|
||||
from invokeai.backend.util.test_utils import install_and_load_model
|
||||
|
||||
|
30
tests/backend/model_manager/model_loading/test_model_load.py
Normal file
30
tests/backend/model_manager/model_loading/test_model_load.py
Normal file
@ -0,0 +1,30 @@
|
||||
"""
|
||||
Test model loading
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_loading(mm2_model_manager: ModelManagerServiceBase, embedding_file: Path):
|
||||
store = mm2_model_manager.store
|
||||
matches = store.search_by_attr(model_name="test_embedding")
|
||||
assert len(matches) == 0
|
||||
key = mm2_model_manager.install.register_path(embedding_file)
|
||||
loaded_model = mm2_model_manager.load_model_by_config(store.get_model(key))
|
||||
assert loaded_model is not None
|
||||
assert loaded_model.config.key == key
|
||||
with loaded_model as model:
|
||||
assert isinstance(model, TextualInversionModelRaw)
|
||||
loaded_model_2 = mm2_model_manager.load_model_by_key(key)
|
||||
assert loaded_model.config.key == loaded_model_2.config.key
|
||||
|
||||
loaded_model_3 = mm2_model_manager.load_model_by_attr(
|
||||
model_name=loaded_model.config.name,
|
||||
model_type=loaded_model.config.type,
|
||||
base_model=loaded_model.config.base,
|
||||
)
|
||||
assert loaded_model.config.key == loaded_model_3.config.key
|
@ -2,27 +2,32 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pytest import FixtureRequest
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadQueueService
|
||||
from invokeai.app.services.download import DownloadQueueService, DownloadQueueServiceBase
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
|
||||
from invokeai.app.services.model_records import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_load import ModelLoadService, ModelLoadServiceBase
|
||||
from invokeai.app.services.model_manager import ModelManagerService, ModelManagerServiceBase
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
from invokeai.backend.model_manager.config import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata import ModelMetadataStore
|
||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.backend.model_manager_2.model_metadata.metadata_examples import (
|
||||
from tests.backend.model_manager.model_metadata.metadata_examples import (
|
||||
RepoCivitaiModelMetadata1,
|
||||
RepoCivitaiVersionMetadata1,
|
||||
RepoHFMetadata1,
|
||||
@ -85,15 +90,77 @@ def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
|
||||
app_config = InvokeAIAppConfig(
|
||||
root=mm2_root_dir,
|
||||
models_dir=mm2_root_dir / "models",
|
||||
log_level="info",
|
||||
)
|
||||
return app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL:
|
||||
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
def stop_queue() -> None:
|
||||
download_queue.stop()
|
||||
|
||||
request.addfinalizer(stop_queue)
|
||||
return download_queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
|
||||
return mm2_record_store.metadata_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
|
||||
ram_cache = ModelCache(
|
||||
logger=InvokeAILogger.get_logger(),
|
||||
max_cache_size=mm2_app_config.ram_cache_size,
|
||||
max_vram_cache_size=mm2_app_config.vram_cache_size,
|
||||
)
|
||||
convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
|
||||
return ModelLoadService(
|
||||
app_config=mm2_app_config,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(
|
||||
mm2_app_config: InvokeAIAppConfig,
|
||||
mm2_download_queue: DownloadQueueServiceBase,
|
||||
mm2_session: Session,
|
||||
request: FixtureRequest,
|
||||
) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=mm2_download_queue,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
|
||||
def stop_installer() -> None:
|
||||
installer.stop()
|
||||
time.sleep(0.1) # avoid error message from the logger when it is closed before thread prints final message
|
||||
|
||||
request.addfinalizer(stop_installer)
|
||||
return installer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
|
||||
logger = InvokeAILogger.get_logger(config=mm2_app_config)
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
store = ModelRecordServiceSQL(db)
|
||||
store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
|
||||
# add five simple config records to the database
|
||||
raw1 = {
|
||||
"path": "/tmp/foo1",
|
||||
@ -152,15 +219,16 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceSQL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStore:
|
||||
db = mm2_record_store._db # to ensure we are sharing the same database
|
||||
return ModelMetadataStore(db)
|
||||
def mm2_model_manager(
|
||||
mm2_record_store: ModelRecordServiceBase, mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase
|
||||
) -> ModelManagerServiceBase:
|
||||
return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
"""This fixtures defines a series of mock URLs for testing download and installation."""
|
||||
sess = TestSession()
|
||||
sess: Session = TestSession()
|
||||
sess.mount(
|
||||
"https://test.com/missing_model.safetensors",
|
||||
TestAdapter(
|
||||
@ -240,26 +308,3 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
),
|
||||
)
|
||||
return sess
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mm2_installer(mm2_app_config: InvokeAIAppConfig, mm2_session: Session) -> ModelInstallServiceBase:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||
events = DummyEventService()
|
||||
store = ModelRecordServiceSQL(db)
|
||||
metadata_store = ModelMetadataStore(db)
|
||||
|
||||
download_queue = DownloadQueueService(requests_session=mm2_session)
|
||||
download_queue.start()
|
||||
|
||||
installer = ModelInstallService(
|
||||
app_config=mm2_app_config,
|
||||
record_store=store,
|
||||
download_queue=download_queue,
|
||||
metadata_store=metadata_store,
|
||||
event_bus=events,
|
||||
session=mm2_session,
|
||||
)
|
||||
installer.start()
|
||||
return installer
|
File diff suppressed because one or more lines are too long
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test model metadata fetching and storage.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@ -8,6 +9,7 @@ import pytest
|
||||
from pydantic.networks import HttpUrl
|
||||
from requests.sessions import Session
|
||||
|
||||
from invokeai.app.services.model_metadata import ModelMetadataStoreBase
|
||||
from invokeai.backend.model_manager.config import ModelRepoVariant
|
||||
from invokeai.backend.model_manager.metadata import (
|
||||
CivitaiMetadata,
|
||||
@ -15,14 +17,13 @@ from invokeai.backend.model_manager.metadata import (
|
||||
CommercialUsage,
|
||||
HuggingFaceMetadata,
|
||||
HuggingFaceMetadataFetch,
|
||||
ModelMetadataStore,
|
||||
UnknownMetadataException,
|
||||
)
|
||||
from invokeai.backend.model_manager.util import select_hf_files
|
||||
from tests.backend.model_manager_2.model_manager_2_fixtures import * # noqa F403
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
tags = {"text-to-image", "diffusers"}
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
@ -40,7 +41,7 @@ def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert mm2_metadata_store.list_tags() == tags
|
||||
|
||||
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
input_metadata = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
@ -57,7 +58,7 @@ def test_metadata_store_update(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
assert input_metadata == output_metadata
|
||||
|
||||
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStore) -> None:
|
||||
def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
|
||||
metadata1 = HuggingFaceMetadata(
|
||||
name="sdxl-vae",
|
||||
author="stabilityai",
|
||||
@ -133,7 +134,7 @@ def test_metadata_civitai_fetch(mm2_session: Session) -> None:
|
||||
assert metadata.id == 215485
|
||||
assert metadata.author == "test_author" # note that this is not the same as the original from Civitai
|
||||
assert metadata.allow_commercial_use # changed to make sure we are reading locally not remotely
|
||||
assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
|
||||
assert CommercialUsage("RentCivit") in metadata.restrictions.AllowCommercialUse
|
||||
assert metadata.version_id == 242807
|
||||
assert metadata.tags == {"tool", "turbo", "sdxl turbo"}
|
||||
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import LibcUtil, Struct_mallinfo2
|
||||
from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2
|
||||
|
||||
|
||||
def test_libc_util_mallinfo2():
|
@ -5,8 +5,8 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.lora import ModelPatcher
|
||||
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.lora import LoRALayer, LoRAModelRaw
|
||||
from invokeai.backend.model_patcher import ModelPatcher
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from invokeai.backend.model_management.libc_util import Struct_mallinfo2
|
||||
from invokeai.backend.model_management.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
|
||||
from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
|
||||
|
||||
|
||||
def test_memory_snapshot_capture():
|
||||
@ -26,6 +26,7 @@ snapshots = [
|
||||
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
|
||||
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
|
||||
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
|
||||
print(msg)
|
||||
|
||||
expected_lines = 0
|
||||
if snapshot_1 is not None and snapshot_2 is not None:
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
|
||||
from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
@ -192,6 +192,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
"text_encoder/model.onnx",
|
||||
"text_encoder_2/config.json",
|
||||
"text_encoder_2/model.onnx",
|
||||
"text_encoder_2/model.onnx_data",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
@ -202,6 +203,7 @@ def sdxl_base_files() -> List[Path]:
|
||||
"tokenizer_2/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/model.onnx",
|
||||
"unet/model.onnx_data",
|
||||
"vae_decoder/config.json",
|
||||
"vae_decoder/model.onnx",
|
||||
"vae_encoder/config.json",
|
@ -1,6 +1,7 @@
|
||||
"""
|
||||
Test interaction of logging with configuration system.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
|
Reference in New Issue
Block a user