Tidy names and locations of modules
- Rename old "model_management" directory to "model_management_OLD" in order to catch dangling references to the original model manager.
- Caught and fixed most dangling references (still checking).
- Rename lora, textual_inversion and model_patcher modules.
- Introduce a RawModel base class to simplify the Union returned by the model loaders.
- Tidy up the model manager 2-related tests. Add useful fixtures, and a finalizer to the queue and installer fixtures that will stop the services and release threads.
Committed by psychedelicious (parent ba1f8878dd, commit 2ad0752582)
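The RawModel base class mentioned in the commit message is the piece that lets the loaders return one type instead of a wide Union of wrapper classes. A rough sketch of that idea only; the class names and the minimal interface below are illustrative assumptions, not the commit's actual definitions:

# Editor's sketch of the RawModel idea; names and interface are assumptions.
from abc import ABC, abstractmethod
from typing import Optional

import torch


class RawModel(ABC):
    """Common base for non-diffusers model wrappers (LoRA, textual inversion, ...)."""

    @abstractmethod
    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        """Move the wrapped weights to the given device/dtype."""


class TextualInversionModelRawSketch(RawModel):
    """One concrete wrapper; a loader can now be typed to return RawModel."""

    def __init__(self, embedding: torch.Tensor) -> None:
        self.embedding = embedding

    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
        self.embedding = self.embedding.to(device=device, dtype=dtype)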
tests/backend/model_manager/data/invokeai_root/README (new file, +1)
@@ -0,0 +1 @@
This is an empty invokeai root that is used as a template for model manager tests.
@@ -0,0 +1,79 @@
model:
  base_learning_rate: 1.0e-04
  target: invokeai.backend.models.diffusion.ddpm.LatentDiffusion
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False

    scheduler_config: # 10000 warmup steps
      target: invokeai.backend.stable_diffusion.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [ 10000 ]
        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
        f_start: [ 1.e-6 ]
        f_max: [ 1. ]
        f_min: [ 1. ]

    personalization_config:
      target: invokeai.backend.stable_diffusion.embedding_manager.EmbeddingManager
      params:
        placeholder_strings: ["*"]
        initializer_words: ['sculpture']
        per_image_tokens: false
        num_vectors_per_token: 1
        progressive_words: False

    unet_config:
      target: invokeai.backend.stable_diffusion.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_heads: 8
        use_spatial_transformer: True
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: True
        legacy: False

    first_stage_config:
      target: invokeai.backend.stable_diffusion.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: invokeai.backend.stable_diffusion.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -0,0 +1 @@
This is a template empty invokeai root directory used to test model management.
@@ -0,0 +1 @@
This is a template empty invokeai root directory used to test model management.
@@ -0,0 +1,34 @@
{
  "_class_name": "StableDiffusionXLPipeline",
  "_diffusers_version": "0.23.0",
  "_name_or_path": "stabilityai/sdxl-turbo",
  "force_zeros_for_empty_prompt": true,
  "scheduler": [
    "diffusers",
    "EulerAncestralDiscreteScheduler"
  ],
  "text_encoder": [
    "transformers",
    "CLIPTextModel"
  ],
  "text_encoder_2": [
    "transformers",
    "CLIPTextModelWithProjection"
  ],
  "tokenizer": [
    "transformers",
    "CLIPTokenizer"
  ],
  "tokenizer_2": [
    "transformers",
    "CLIPTokenizer"
  ],
  "unet": [
    "diffusers",
    "UNet2DConditionModel"
  ],
  "vae": [
    "diffusers",
    "AutoencoderKL"
  ]
}
@@ -0,0 +1,17 @@
{
  "_class_name": "EulerAncestralDiscreteScheduler",
  "_diffusers_version": "0.23.0",
  "beta_end": 0.012,
  "beta_schedule": "scaled_linear",
  "beta_start": 0.00085,
  "clip_sample": false,
  "interpolation_type": "linear",
  "num_train_timesteps": 1000,
  "prediction_type": "epsilon",
  "sample_max_value": 1.0,
  "set_alpha_to_one": false,
  "skip_prk_steps": true,
  "steps_offset": 1,
  "timestep_spacing": "trailing",
  "trained_betas": null
}
@@ -0,0 +1,25 @@
{
  "_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder",
  "architectures": [
    "CLIPTextModel"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "projection_dim": 768,
  "torch_dtype": "float16",
  "transformers_version": "4.35.0",
  "vocab_size": 49408
}
@@ -0,0 +1,25 @@
{
  "_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/text_encoder_2",
  "architectures": [
    "CLIPTextModelWithProjection"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "dropout": 0.0,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_size": 1280,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 5120,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 77,
  "model_type": "clip_text_model",
  "num_attention_heads": 20,
  "num_hidden_layers": 32,
  "pad_token_id": 1,
  "projection_dim": 1280,
  "torch_dtype": "float16",
  "transformers_version": "4.35.0",
  "vocab_size": 49408
}
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
@@ -0,0 +1,30 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "49406": {
      "content": "<|startoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49407": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|startoftext|>",
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 77,
  "pad_token": "<|endoftext|>",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": "<|endoftext|>"
}
@@ -0,0 +1,30 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": {
    "content": "!",
    "lstrip": false,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
@@ -0,0 +1,38 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "!",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49406": {
      "content": "<|startoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49407": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|startoftext|>",
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 77,
  "pad_token": "!",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": "<|endoftext|>"
}
@@ -0,0 +1,73 @@
{
  "_class_name": "UNet2DConditionModel",
  "_diffusers_version": "0.23.0",
  "_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/unet",
  "act_fn": "silu",
  "addition_embed_type": "text_time",
  "addition_embed_type_num_heads": 64,
  "addition_time_embed_dim": 256,
  "attention_head_dim": [
    5,
    10,
    20
  ],
  "attention_type": "default",
  "block_out_channels": [
    320,
    640,
    1280
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "class_embeddings_concat": false,
  "conv_in_kernel": 3,
  "conv_out_kernel": 3,
  "cross_attention_dim": 2048,
  "cross_attention_norm": null,
  "down_block_types": [
    "DownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D"
  ],
  "downsample_padding": 1,
  "dropout": 0.0,
  "dual_cross_attention": false,
  "encoder_hid_dim": null,
  "encoder_hid_dim_type": null,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_only_cross_attention": null,
  "mid_block_scale_factor": 1,
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_attention_heads": null,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": 2816,
  "resnet_out_scale_factor": 1.0,
  "resnet_skip_time_act": false,
  "resnet_time_scale_shift": "default",
  "reverse_transformer_layers_per_block": null,
  "sample_size": 64,
  "time_cond_proj_dim": null,
  "time_embedding_act_fn": null,
  "time_embedding_dim": null,
  "time_embedding_type": "positional",
  "timestep_post_act": null,
  "transformer_layers_per_block": [
    1,
    2,
    10
  ],
  "up_block_types": [
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "UpBlock2D"
  ],
  "upcast_attention": null,
  "use_linear_projection": true
}
@@ -0,0 +1,32 @@
{
  "_class_name": "AutoencoderKL",
  "_diffusers_version": "0.23.0",
  "_name_or_path": "/home/lstein/.cache/huggingface/hub/models--stabilityai--sdxl-turbo/snapshots/fbda35297a8280789ffe2e25206800702fa5c4c1/vae",
  "act_fn": "silu",
  "block_out_channels": [
    128,
    256,
    512,
    512
  ],
  "down_block_types": [
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D",
    "DownEncoderBlock2D"
  ],
  "force_upcast": true,
  "in_channels": 3,
  "latent_channels": 4,
  "layers_per_block": 2,
  "norm_num_groups": 32,
  "out_channels": 3,
  "sample_size": 1024,
  "scaling_factor": 0.13025,
  "up_block_types": [
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D",
    "UpDecoderBlock2D"
  ]
}
Binary file not shown.
tests/backend/model_manager/model_loading/test_model_load.py (new file, +21)
@@ -0,0 +1,21 @@
"""
Test model loading
"""

from pathlib import Path

from invokeai.app.services.model_install import ModelInstallServiceBase
from invokeai.app.services.model_load import ModelLoadServiceBase
from invokeai.backend.textual_inversion import TextualInversionModelRaw
from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


def test_loading(mm2_installer: ModelInstallServiceBase, mm2_loader: ModelLoadServiceBase, embedding_file: Path):
    store = mm2_installer.record_store
    matches = store.search_by_attr(model_name="test_embedding")
    assert len(matches) == 0
    key = mm2_installer.register_path(embedding_file)
    loaded_model = mm2_loader.load_model_by_config(store.get_model(key))
    assert loaded_model is not None
    assert loaded_model.config.key == key
    with loaded_model as model:
        assert isinstance(model, TextualInversionModelRaw)
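The `with loaded_model as model:` idiom in the test suggests a context-manager wrapper around the cache entry: entering yields the underlying raw model, exiting releases it. A hypothetical sketch of that shape only; all names and the locking behavior here are assumptions for illustration, not the commit's actual classes:

from dataclasses import dataclass
from typing import Any


class _CacheLockerSketch:
    """Stand-in for a RAM-cache locker that pins/unpins a cached model."""

    def __init__(self, model: Any) -> None:
        self._model = model
        self.locked = False

    def lock(self) -> Any:
        self.locked = True
        return self._model

    def unlock(self) -> None:
        self.locked = False


@dataclass
class LoadedModelSketch:
    config: Any  # the record used to load the model, exposing .key
    _locker: _CacheLockerSketch

    def __enter__(self) -> Any:
        return self._locker.lock()  # pin in cache, hand back the raw model

    def __exit__(self, *exc: Any) -> None:
        self._locker.unlock()  # safe to evict again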
tests/backend/model_manager/model_manager_fixtures.py (new file, +310)
@@ -0,0 +1,310 @@
# Fixtures to support testing of the model_manager v2 installer, metadata and record store

import os
import shutil
from pathlib import Path
from typing import Any, Dict, List

import pytest
from pytest import FixtureRequest
from pydantic import BaseModel
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadQueueServiceBase, DownloadQueueService
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.model_manager import ModelManagerServiceBase, ModelManagerService
from invokeai.app.services.model_load import ModelLoadServiceBase, ModelLoadService
from invokeai.app.services.model_install import ModelInstallService, ModelInstallServiceBase
from invokeai.app.services.model_metadata import ModelMetadataStoreBase, ModelMetadataStoreSQL
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL
from invokeai.backend.model_manager.config import (
    BaseModelType,
    ModelFormat,
    ModelType,
)
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
from invokeai.backend.util.logging import InvokeAILogger
from tests.backend.model_manager.model_metadata.metadata_examples import (
    RepoCivitaiModelMetadata1,
    RepoCivitaiVersionMetadata1,
    RepoHFMetadata1,
    RepoHFMetadata1_nofp16,
    RepoHFModelJson1,
)
from tests.fixtures.sqlite_database import create_mock_sqlite_database


class DummyEvent(BaseModel):
    """Dummy Event to use with Dummy Event service."""

    event_name: str
    payload: Dict[str, Any]


class DummyEventService(EventServiceBase):
    """Dummy event service for testing."""

    events: List[DummyEvent]

    def __init__(self) -> None:
        super().__init__()
        self.events = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Dispatch an event by appending it to self.events."""
        self.events.append(DummyEvent(event_name=payload["event"], payload=payload["data"]))


# Create a temporary directory using the contents of `./data/invokeai_root` as the template
@pytest.fixture
def mm2_root_dir(tmp_path_factory) -> Path:
    root_template = Path(__file__).resolve().parent / "data" / "invokeai_root"
    temp_dir: Path = tmp_path_factory.mktemp("data") / "invokeai_root"
    shutil.copytree(root_template, temp_dir)
    return temp_dir


@pytest.fixture
def mm2_model_files(tmp_path_factory) -> Path:
    root_template = Path(__file__).resolve().parent / "data" / "test_files"
    temp_dir: Path = tmp_path_factory.mktemp("data") / "test_files"
    shutil.copytree(root_template, temp_dir)
    return temp_dir


@pytest.fixture
def embedding_file(mm2_model_files: Path) -> Path:
    return mm2_model_files / "test_embedding.safetensors"


@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path:
    return mm2_model_files / "test-diffusers-main"


@pytest.fixture
def mm2_app_config(mm2_root_dir: Path) -> InvokeAIAppConfig:
    app_config = InvokeAIAppConfig(
        root=mm2_root_dir,
        models_dir=mm2_root_dir / "models",
        log_level="info",
    )
    return app_config


@pytest.fixture
def mm2_download_queue(mm2_session: Session, request: FixtureRequest) -> DownloadQueueServiceBase:
    download_queue = DownloadQueueService(requests_session=mm2_session)
    download_queue.start()

    def stop_queue() -> None:
        download_queue.stop()

    request.addfinalizer(stop_queue)
    return download_queue


@pytest.fixture
def mm2_metadata_store(mm2_record_store: ModelRecordServiceSQL) -> ModelMetadataStoreBase:
    return mm2_record_store.metadata_store


@pytest.fixture
def mm2_loader(mm2_app_config: InvokeAIAppConfig, mm2_record_store: ModelRecordServiceBase) -> ModelLoadServiceBase:
    ram_cache = ModelCache(
        logger=InvokeAILogger.get_logger(),
        max_cache_size=mm2_app_config.ram_cache_size,
        max_vram_cache_size=mm2_app_config.vram_cache_size,
    )
    convert_cache = ModelConvertCache(mm2_app_config.models_convert_cache_path)
    return ModelLoadService(
        app_config=mm2_app_config,
        record_store=mm2_record_store,
        ram_cache=ram_cache,
        convert_cache=convert_cache,
    )


@pytest.fixture
def mm2_installer(
    mm2_app_config: InvokeAIAppConfig,
    mm2_download_queue: DownloadQueueServiceBase,
    mm2_session: Session,
    request: FixtureRequest,
) -> ModelInstallServiceBase:
    logger = InvokeAILogger.get_logger()
    db = create_mock_sqlite_database(mm2_app_config, logger)
    events = DummyEventService()
    store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))

    installer = ModelInstallService(
        app_config=mm2_app_config,
        record_store=store,
        download_queue=mm2_download_queue,
        event_bus=events,
        session=mm2_session,
    )
    installer.start()

    def stop_installer() -> None:
        installer.stop()

    request.addfinalizer(stop_installer)
    return installer


@pytest.fixture
def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBase:
    logger = InvokeAILogger.get_logger(config=mm2_app_config)
    db = create_mock_sqlite_database(mm2_app_config, logger)
    store = ModelRecordServiceSQL(db, ModelMetadataStoreSQL(db))
    # add five simple config records to the database
    raw1 = {
        "path": "/tmp/foo1",
        "format": ModelFormat("diffusers"),
        "name": "test2",
        "base": BaseModelType("sd-2"),
        "type": ModelType("vae"),
        "original_hash": "111222333444",
        "source": "stabilityai/sdxl-vae",
    }
    raw2 = {
        "path": "/tmp/foo2.ckpt",
        "name": "model1",
        "format": ModelFormat("checkpoint"),
        "base": BaseModelType("sd-1"),
        "type": "main",
        "config": "/tmp/foo.yaml",
        "variant": "normal",
        "original_hash": "111222333444",
        "source": "https://civitai.com/models/206883/split",
    }
    raw3 = {
        "path": "/tmp/foo3",
        "format": ModelFormat("diffusers"),
        "name": "test3",
        "base": BaseModelType("sdxl"),
        "type": ModelType("main"),
        "original_hash": "111222333444",
        "source": "author3/model3",
        "description": "This is test 3",
    }
    raw4 = {
        "path": "/tmp/foo4",
        "format": ModelFormat("diffusers"),
        "name": "test4",
        "base": BaseModelType("sdxl"),
        "type": ModelType("lora"),
        "original_hash": "111222333444",
        "source": "author4/model4",
    }
    raw5 = {
        "path": "/tmp/foo5",
        "format": ModelFormat("diffusers"),
        "name": "test5",
        "base": BaseModelType("sd-1"),
        "type": ModelType("lora"),
        "original_hash": "111222333444",
        "source": "author4/model5",
    }
    store.add_model("test_config_1", raw1)
    store.add_model("test_config_2", raw2)
    store.add_model("test_config_3", raw3)
    store.add_model("test_config_4", raw4)
    store.add_model("test_config_5", raw5)
    return store


@pytest.fixture
def mm2_model_manager(
    mm2_record_store: ModelRecordServiceBase,
    mm2_installer: ModelInstallServiceBase,
    mm2_loader: ModelLoadServiceBase,
) -> ModelManagerServiceBase:
    return ModelManagerService(store=mm2_record_store, install=mm2_installer, load=mm2_loader)


@pytest.fixture
def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
    """This fixture defines a series of mock URLs for testing download and installation."""
    sess: Session = TestSession()
    sess.mount(
        "https://test.com/missing_model.safetensors",
        TestAdapter(
            b"missing",
            status=404,
        ),
    )
    sess.mount(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo",
        TestAdapter(
            RepoHFMetadata1,
            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
        ),
    )
    sess.mount(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo-nofp16",
        TestAdapter(
            RepoHFMetadata1_nofp16,
            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1_nofp16)},
        ),
    )
    sess.mount(
        "https://civitai.com/api/v1/model-versions/242807",
        TestAdapter(
            RepoCivitaiVersionMetadata1,
            headers={
                "Content-Length": len(RepoCivitaiVersionMetadata1),
            },
        ),
    )
    sess.mount(
        "https://civitai.com/api/v1/models/215485",
        TestAdapter(
            RepoCivitaiModelMetadata1,
            headers={
                "Content-Length": len(RepoCivitaiModelMetadata1),
            },
        ),
    )
    sess.mount(
        "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/model_index.json",
        TestAdapter(
            RepoHFModelJson1,
            headers={
                "Content-Length": len(RepoHFModelJson1),
            },
        ),
    )
    with open(embedding_file, "rb") as f:
        data = f.read()  # file is small - just 15K
    sess.mount(
        "https://www.test.foo/download/test_embedding.safetensors",
        TestAdapter(data, headers={"Content-Type": "application/octet-stream", "Content-Length": len(data)}),
    )
    sess.mount(
        "https://huggingface.co/api/models/stabilityai/sdxl-turbo",
        TestAdapter(
            RepoHFMetadata1,
            headers={"Content-Type": "application/json; charset=utf-8", "Content-Length": len(RepoHFMetadata1)},
        ),
    )
    for root, _, files in os.walk(diffusers_dir):
        for name in files:
            path = Path(root, name)
            url_base = path.relative_to(diffusers_dir).as_posix()
            url = f"https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/{url_base}"
            with open(path, "rb") as f:
                data = f.read()
            sess.mount(
                url,
                TestAdapter(
                    data,
                    headers={
                        "Content-Type": "application/json; charset=utf-8",
                        "Content-Length": len(data),
                    },
                ),
            )
    return sess
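Both mm2_download_queue and mm2_installer above register request.addfinalizer callbacks, which is what the commit message means by stopping the services and releasing threads: finalizers run during fixture teardown whether the test passes or fails. The same pattern in isolation, with a dummy threaded service standing in for the real queue/installer (the service class here is illustrative):

import threading

import pytest
from pytest import FixtureRequest


class _DummyThreadedService:
    """Stand-in for a service that owns a worker thread (e.g. a download queue)."""

    def __init__(self) -> None:
        self._stop_event = threading.Event()
        self._thread = threading.Thread(target=self._stop_event.wait, daemon=True)

    def start(self) -> None:
        self._thread.start()

    def stop(self) -> None:
        self._stop_event.set()
        self._thread.join()


@pytest.fixture
def running_service(request: FixtureRequest) -> _DummyThreadedService:
    service = _DummyThreadedService()
    service.start()
    # The finalizer runs on teardown, pass or fail, so the worker thread is
    # always joined and never leaks into the next test.
    request.addfinalizer(service.stop)
    return service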
File diff suppressed because one or more lines are too long
@@ -0,0 +1,201 @@
"""
Test model metadata fetching and storage.
"""
import datetime
from pathlib import Path

import pytest
from pydantic.networks import HttpUrl
from requests.sessions import Session

from invokeai.app.services.model_metadata import ModelMetadataStoreBase
from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.metadata import (
    CivitaiMetadata,
    CivitaiMetadataFetch,
    CommercialUsage,
    HuggingFaceMetadata,
    HuggingFaceMetadataFetch,
    UnknownMetadataException,
)
from invokeai.backend.model_manager.util import select_hf_files
from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403


def test_metadata_store_put_get(mm2_metadata_store: ModelMetadataStoreBase) -> None:
    tags = {"text-to-image", "diffusers"}
    input_metadata = HuggingFaceMetadata(
        name="sdxl-vae",
        author="stabilityai",
        tags=tags,
        id="stabilityai/sdxl-vae",
        tag_dict={"license": "other"},
        last_modified=datetime.datetime.now(),
    )
    mm2_metadata_store.add_metadata("test_config_1", input_metadata)
    output_metadata = mm2_metadata_store.get_metadata("test_config_1")
    assert input_metadata == output_metadata
    with pytest.raises(UnknownMetadataException):
        mm2_metadata_store.add_metadata("unknown_key", input_metadata)
    assert mm2_metadata_store.list_tags() == tags


def test_metadata_store_update(mm2_metadata_store: ModelMetadataStoreBase) -> None:
    input_metadata = HuggingFaceMetadata(
        name="sdxl-vae",
        author="stabilityai",
        tags={"text-to-image", "diffusers"},
        id="stabilityai/sdxl-vae",
        tag_dict={"license": "other"},
        last_modified=datetime.datetime.now(),
    )
    mm2_metadata_store.add_metadata("test_config_1", input_metadata)
    input_metadata.name = "new-name"
    mm2_metadata_store.update_metadata("test_config_1", input_metadata)
    output_metadata = mm2_metadata_store.get_metadata("test_config_1")
    assert output_metadata.name == "new-name"
    assert input_metadata == output_metadata


def test_metadata_search(mm2_metadata_store: ModelMetadataStoreBase) -> None:
    metadata1 = HuggingFaceMetadata(
        name="sdxl-vae",
        author="stabilityai",
        tags={"text-to-image", "diffusers"},
        id="stabilityai/sdxl-vae",
        tag_dict={"license": "other"},
        last_modified=datetime.datetime.now(),
    )
    metadata2 = HuggingFaceMetadata(
        name="model2",
        author="stabilityai",
        tags={"text-to-image", "diffusers", "community-contributed"},
        id="author2/model2",
        tag_dict={"license": "other"},
        last_modified=datetime.datetime.now(),
    )
    metadata3 = HuggingFaceMetadata(
        name="model3",
        author="author3",
        tags={"text-to-image", "checkpoint", "community-contributed"},
        id="author3/model3",
        tag_dict={"license": "other"},
        last_modified=datetime.datetime.now(),
    )
    mm2_metadata_store.add_metadata("test_config_1", metadata1)
    mm2_metadata_store.add_metadata("test_config_2", metadata2)
    mm2_metadata_store.add_metadata("test_config_3", metadata3)

    matches = mm2_metadata_store.search_by_author("stabilityai")
    assert len(matches) == 2
    assert "test_config_1" in matches
    assert "test_config_2" in matches
    matches = mm2_metadata_store.search_by_author("Sherlock Holmes")
    assert not matches

    matches = mm2_metadata_store.search_by_name("model3")
    assert len(matches) == 1
    assert "test_config_3" in matches

    matches = mm2_metadata_store.search_by_tag({"text-to-image"})
    assert len(matches) == 3

    matches = mm2_metadata_store.search_by_tag({"text-to-image", "diffusers"})
    assert len(matches) == 2
    assert "test_config_1" in matches
    assert "test_config_2" in matches

    matches = mm2_metadata_store.search_by_tag({"checkpoint", "community-contributed"})
    assert len(matches) == 1
    assert "test_config_3" in matches

    # does the tag table update correctly?
    matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
    assert not matches
    assert mm2_metadata_store.list_tags() == {"text-to-image", "diffusers", "community-contributed", "checkpoint"}
    metadata3.tags.add("licensed-for-commercial-use")
    mm2_metadata_store.update_metadata("test_config_3", metadata3)
    assert mm2_metadata_store.list_tags() == {
        "text-to-image",
        "diffusers",
        "community-contributed",
        "checkpoint",
        "licensed-for-commercial-use",
    }
    matches = mm2_metadata_store.search_by_tag({"checkpoint", "licensed-for-commercial-use"})
    assert len(matches) == 1


def test_metadata_civitai_fetch(mm2_session: Session) -> None:
    fetcher = CivitaiMetadataFetch(mm2_session)
    metadata = fetcher.from_url(HttpUrl("https://civitai.com/models/215485/SDXL-turbo"))
    assert isinstance(metadata, CivitaiMetadata)
    assert metadata.id == 215485
    assert metadata.author == "test_author"  # note that this is not the same as the original from Civitai
    assert metadata.allow_commercial_use  # changed to make sure we are reading locally not remotely
    assert metadata.restrictions.AllowCommercialUse == CommercialUsage("RentCivit")
    assert metadata.version_id == 242807
    assert metadata.tags == {"tool", "turbo", "sdxl turbo"}


def test_metadata_hf_fetch(mm2_session: Session) -> None:
    fetcher = HuggingFaceMetadataFetch(mm2_session)
    metadata = fetcher.from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
    assert isinstance(metadata, HuggingFaceMetadata)
    assert metadata.author == "test_author"  # this is not the same as the original
    assert metadata.files
    assert metadata.tags == {
        "diffusers",
        "onnx",
        "safetensors",
        "text-to-image",
        "license:other",
        "has_space",
        "diffusers:StableDiffusionXLPipeline",
        "region:us",
    }


def test_metadata_hf_filter(mm2_session: Session) -> None:
    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
    assert isinstance(metadata, HuggingFaceMetadata)
    files = [x.path for x in metadata.files]
    fp16_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
    assert Path("sdxl-turbo/text_encoder/model.fp16.safetensors") in fp16_files
    assert Path("sdxl-turbo/text_encoder/model.safetensors") not in fp16_files

    fp32_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp32"))
    assert Path("sdxl-turbo/text_encoder/model.safetensors") in fp32_files
    assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in fp32_files

    onnx_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("onnx"))
    assert Path("sdxl-turbo/text_encoder/model.onnx") in onnx_files
    assert Path("sdxl-turbo/text_encoder/model.safetensors") not in onnx_files

    default_files = select_hf_files.filter_files(files)
    assert Path("sdxl-turbo/text_encoder/model.safetensors") in default_files
    assert Path("sdxl-turbo/text_encoder/model.16.safetensors") not in default_files

    openvino_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("openvino"))
    print(openvino_files)
    assert len(openvino_files) == 0

    flax_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("flax"))
    print(flax_files)
    assert not flax_files

    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(
        HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo-nofp16")
    )
    assert isinstance(metadata, HuggingFaceMetadata)
    files = [x.path for x in metadata.files]
    filtered_files = select_hf_files.filter_files(files, variant=ModelRepoVariant("fp16"))
    assert (
        Path("sdxl-turbo-nofp16/text_encoder/model.safetensors") in filtered_files
    )  # confirm that default is returned
    assert Path("sdxl-turbo-nofp16/text_encoder/model.16.safetensors") not in filtered_files


def test_metadata_hf_urls(mm2_session: Session) -> None:
    metadata = HuggingFaceMetadataFetch(mm2_session).from_url(HttpUrl("https://huggingface.co/stabilityai/sdxl-turbo"))
    assert isinstance(metadata, HuggingFaceMetadata)
tests/backend/model_manager/test_libc_util.py (new file, +27)
@@ -0,0 +1,27 @@
import pytest

from invokeai.backend.model_manager.util.libc_util import LibcUtil, Struct_mallinfo2


def test_libc_util_mallinfo2():
    """Smoke test of LibcUtil().mallinfo2()."""
    try:
        libc = LibcUtil()
    except OSError:
        # TODO: Set the expected result preemptively based on the system properties.
        pytest.xfail("libc shared library is not available on this system.")

    try:
        info = libc.mallinfo2()
    except AttributeError:
        pytest.xfail("`mallinfo2` is not available on this system, likely due to glibc < 2.33.")

    assert info.arena > 0


def test_struct_mallinfo2_to_str():
    """Smoke test of Struct_mallinfo2.__str__()."""
    info = Struct_mallinfo2()
    info_str = str(info)

    assert len(info_str) > 0
tests/backend/model_manager/test_lora.py (new file, +102)
@@ -0,0 +1,102 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored

# test that LoRA patching works on both CPU and CUDA

import pytest
import torch

from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.lora import LoRALayer, LoRAModelRaw


@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora(device):
    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
    result, and that model/LoRA tensors are moved between devices as expected.
    """

    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    lora_weight = 0.5
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)

    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on its original device.
        assert model["linear_layer_1"].weight.data.device.type == device

        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)

    # After unpatching, the original model weights should have been restored on the original device.
    assert model["linear_layer_1"].weight.data.device.type == device
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()

    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model["linear_layer_1"].weight.data.device.type == "cpu"

        # Move the model to the GPU.
        assert model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model["linear_layer_1"].weight.data.device.type == "cuda"
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)
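Why `expected_patched_linear_weight` above is `orig + lora_dim * lora_weight`: with all-ones down and up factors, every entry of `up @ down` equals the inner dimension `lora_dim`, so the scaled delta added to each weight is `lora_dim * lora_weight`. A quick standalone check of that arithmetic:

import torch

lora_dim, in_features, out_features, lora_weight = 2, 4, 8, 0.5
down = torch.ones(lora_dim, in_features)
up = torch.ones(out_features, lora_dim)
delta = lora_weight * (up @ down)  # every entry == lora_dim * lora_weight == 1.0
assert torch.equal(delta, torch.full((out_features, in_features), lora_dim * lora_weight))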
tests/backend/model_manager/test_memory_snapshot.py (new file, +38)
@@ -0,0 +1,38 @@
import pytest

from invokeai.backend.model_manager.util.libc_util import Struct_mallinfo2
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff


def test_memory_snapshot_capture():
    """Smoke test of MemorySnapshot.capture()."""
    snapshot = MemorySnapshot.capture()

    # We just check process_ram, because it is the only field that should be supported on all platforms.
    assert snapshot.process_ram > 0


snapshots = [
    MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
    MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
    MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
    MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
    None,
]


@pytest.mark.parametrize("snapshot_1", snapshots)
@pytest.mark.parametrize("snapshot_2", snapshots)
def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
    """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
    msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
    print(msg)

    expected_lines = 0
    if snapshot_1 is not None and snapshot_2 is not None:
        expected_lines += 1
        if snapshot_1.vram is not None and snapshot_2.vram is not None:
            expected_lines += 1
        if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
            expected_lines += 5

    assert len(msg.splitlines()) == expected_lines
tests/backend/model_manager/test_model_load_optimization.py (new file, +73)
@@ -0,0 +1,73 @@
import pytest
import torch

from invokeai.backend.model_manager.load.optimizations import _no_op, skip_torch_weight_init


@pytest.mark.parametrize(
    ["torch_module", "layer_args"],
    [
        (torch.nn.Linear, {"in_features": 10, "out_features": 20}),
        (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
    ],
)
def test_skip_torch_weight_init_linear(torch_module, layer_args):
    """Test the interactions between `skip_torch_weight_init()` and various torch modules."""
    seed = 123

    # Initialize a torch layer *before* applying `skip_torch_weight_init()`.
    reset_params_fn_before = torch_module.reset_parameters
    torch.manual_seed(seed)
    layer_before = torch_module(**layer_args)

    # Initialize a torch layer while `skip_torch_weight_init()` is applied.
    with skip_torch_weight_init():
        reset_params_fn_during = torch_module.reset_parameters
        torch.manual_seed(123)
        layer_during = torch_module(**layer_args)

    # Initialize a torch layer *after* applying `skip_torch_weight_init()`.
    reset_params_fn_after = torch_module.reset_parameters
    torch.manual_seed(123)
    layer_after = torch_module(**layer_args)

    # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
    assert reset_params_fn_during == _no_op
    assert not torch.allclose(layer_before.weight, layer_during.weight)
    if hasattr(layer_before, "bias"):
        assert not torch.allclose(layer_before.bias, layer_during.bias)

    # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
    assert reset_params_fn_before is reset_params_fn_after
    assert torch.allclose(layer_before.weight, layer_after.weight)
    if hasattr(layer_before, "bias"):
        assert torch.allclose(layer_before.bias, layer_after.bias)


def test_skip_torch_weight_init_restores_base_class_behavior():
    """Test that `skip_torch_weight_init()` correctly restores the original behavior of torch.nn.Conv*d modules. This
    test was created to catch a previous bug where `reset_parameters` was being copied from the base `_ConvNd` class to
    its child classes (like `Conv1d`).
    """
    with skip_torch_weight_init():
        # There is no need to do anything while the context manager is applied, we're just testing that the original
        # behavior is restored correctly.
        pass

    # Mock the behavior of another library that monkey patches `torch.nn.modules.conv._ConvNd.reset_parameters` and
    # expects it to affect all of the sub-classes (e.g. `torch.nn.Conv1D`, `torch.nn.Conv2D`, etc.).
    called_monkey_patched_fn = False

    def monkey_patched_fn(*args, **kwargs):
        nonlocal called_monkey_patched_fn
        called_monkey_patched_fn = True

    saved_fn = torch.nn.modules.conv._ConvNd.reset_parameters
    torch.nn.modules.conv._ConvNd.reset_parameters = monkey_patched_fn
    _ = torch.nn.Conv1d(10, 20, 3)
    torch.nn.modules.conv._ConvNd.reset_parameters = saved_fn

    assert called_monkey_patched_fn
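For orientation, the general technique these tests exercise can be sketched as a context manager that temporarily swaps each module class's reset_parameters for a no-op, so freshly constructed layers skip their default (and relatively expensive) weight initialization. This is an illustrative sketch only; the real implementation lives in invokeai.backend.model_manager.load.optimizations and handles the _ConvNd subtlety the second test describes:

import contextlib

import torch


def _no_op(*args, **kwargs) -> None:
    pass


@contextlib.contextmanager
def skip_weight_init_sketch():
    targets = (torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding)
    saved = {module: module.reset_parameters for module in targets}
    try:
        for module in targets:
            module.reset_parameters = _no_op
        yield
    finally:
        # Restore on the same classes we patched, so subclasses like Conv1d
        # keep inheriting from _ConvNd instead of getting their own copy.
        for module, original in saved.items():
            module.reset_parameters = original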
tests/backend/model_manager/util/test_hf_model_select.py (new file, +241)
@@ -0,0 +1,241 @@
from pathlib import Path
from typing import List

import pytest

from invokeai.backend.model_manager.config import ModelRepoVariant
from invokeai.backend.model_manager.util.select_hf_files import filter_files


# This is the full list of model paths returned by the HF API for sdxl-base
@pytest.fixture
def sdxl_base_files() -> List[Path]:
    return [
        Path(x)
        for x in [
            ".gitattributes",
            "01.png",
            "LICENSE.md",
            "README.md",
            "comparison.png",
            "model_index.json",
            "pipeline.png",
            "scheduler/scheduler_config.json",
            "sd_xl_base_1.0.safetensors",
            "sd_xl_base_1.0_0.9vae.safetensors",
            "sd_xl_offset_example-lora_1.0.safetensors",
            "text_encoder/config.json",
            "text_encoder/flax_model.msgpack",
            "text_encoder/model.fp16.safetensors",
            "text_encoder/model.onnx",
            "text_encoder/model.safetensors",
            "text_encoder/openvino_model.bin",
            "text_encoder/openvino_model.xml",
            "text_encoder_2/config.json",
            "text_encoder_2/flax_model.msgpack",
            "text_encoder_2/model.fp16.safetensors",
            "text_encoder_2/model.onnx",
            "text_encoder_2/model.onnx_data",
            "text_encoder_2/model.safetensors",
            "text_encoder_2/openvino_model.bin",
            "text_encoder_2/openvino_model.xml",
            "tokenizer/merges.txt",
            "tokenizer/special_tokens_map.json",
            "tokenizer/tokenizer_config.json",
            "tokenizer/vocab.json",
            "tokenizer_2/merges.txt",
            "tokenizer_2/special_tokens_map.json",
            "tokenizer_2/tokenizer_config.json",
            "tokenizer_2/vocab.json",
            "unet/config.json",
            "unet/diffusion_flax_model.msgpack",
            "unet/diffusion_pytorch_model.fp16.safetensors",
            "unet/diffusion_pytorch_model.safetensors",
            "unet/model.onnx",
            "unet/model.onnx_data",
            "unet/openvino_model.bin",
            "unet/openvino_model.xml",
            "vae/config.json",
            "vae/diffusion_flax_model.msgpack",
            "vae/diffusion_pytorch_model.fp16.safetensors",
            "vae/diffusion_pytorch_model.safetensors",
            "vae_1_0/config.json",
            "vae_1_0/diffusion_pytorch_model.fp16.safetensors",
            "vae_1_0/diffusion_pytorch_model.safetensors",
            "vae_decoder/config.json",
            "vae_decoder/model.onnx",
            "vae_decoder/openvino_model.bin",
            "vae_decoder/openvino_model.xml",
            "vae_encoder/config.json",
            "vae_encoder/model.onnx",
            "vae_encoder/openvino_model.bin",
            "vae_encoder/openvino_model.xml",
        ]
    ]


# These are what we expect to get when various diffusers variants are requested
@pytest.mark.parametrize(
    "variant,expected_list",
    [
        (
            None,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.safetensors",
            ],
        ),
        (
            ModelRepoVariant.DEFAULT,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.safetensors",
            ],
        ),
        (
            ModelRepoVariant.OPENVINO,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/openvino_model.bin",
                "text_encoder/openvino_model.xml",
                "text_encoder_2/config.json",
                "text_encoder_2/openvino_model.bin",
                "text_encoder_2/openvino_model.xml",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/openvino_model.bin",
                "unet/openvino_model.xml",
                "vae_decoder/config.json",
                "vae_decoder/openvino_model.bin",
                "vae_decoder/openvino_model.xml",
                "vae_encoder/config.json",
                "vae_encoder/openvino_model.bin",
                "vae_encoder/openvino_model.xml",
            ],
        ),
        (
            ModelRepoVariant.FP16,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.fp16.safetensors",
                "text_encoder_2/config.json",
                "text_encoder_2/model.fp16.safetensors",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_pytorch_model.fp16.safetensors",
                "vae/config.json",
                "vae/diffusion_pytorch_model.fp16.safetensors",
                "vae_1_0/config.json",
                "vae_1_0/diffusion_pytorch_model.fp16.safetensors",
            ],
        ),
        (
            ModelRepoVariant.ONNX,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/model.onnx",
                "text_encoder_2/config.json",
                "text_encoder_2/model.onnx",
                "text_encoder_2/model.onnx_data",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/model.onnx",
                "unet/model.onnx_data",
                "vae_decoder/config.json",
                "vae_decoder/model.onnx",
                "vae_encoder/config.json",
                "vae_encoder/model.onnx",
            ],
        ),
        (
            ModelRepoVariant.FLAX,
            [
                "model_index.json",
                "scheduler/scheduler_config.json",
                "text_encoder/config.json",
                "text_encoder/flax_model.msgpack",
                "text_encoder_2/config.json",
                "text_encoder_2/flax_model.msgpack",
                "tokenizer/merges.txt",
                "tokenizer/special_tokens_map.json",
                "tokenizer/tokenizer_config.json",
                "tokenizer/vocab.json",
                "tokenizer_2/merges.txt",
                "tokenizer_2/special_tokens_map.json",
                "tokenizer_2/tokenizer_config.json",
                "tokenizer_2/vocab.json",
                "unet/config.json",
                "unet/diffusion_flax_model.msgpack",
                "vae/config.json",
                "vae/diffusion_flax_model.msgpack",
            ],
        ),
    ],
)
def test_select(sdxl_base_files: List[Path], variant: ModelRepoVariant, expected_list: List[Path]) -> None:
    print(f"testing variant {variant}")
    filtered_files = filter_files(sdxl_base_files, variant)
    assert set(filtered_files) == {Path(x) for x in expected_list}