BREAKING CHANGES: invocations now require model key, not base/type/name

- Implement new model loader and modify invocations and embeddings

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move the seamless and SilenceWarnings utils into a better location.
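
For orientation, here is a minimal sketch of what the headline breaking change means for an invocation's model input; the class and field names below are illustrative placeholders under assumption, not InvokeAI's actual API:

# Illustrative sketch only -- hypothetical names, not code from this commit.
from dataclasses import dataclass

@dataclass
class OldModelField:
    """Before: invocations identified a model by its base/type/name triplet."""
    base_model: str  # e.g. "sd-1"
    model_type: str  # e.g. "main"
    model_name: str  # e.g. "stable-diffusion-v1-5"

@dataclass
class NewModelField:
    """After: invocations pass an opaque record-store key; the loader resolves
    base, type, and path from the model's config record."""
    key: str  # e.g. "3a1b2c4d" (hypothetical database key)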
Lincoln Stein
2024-02-05 22:56:32 -05:00
committed by psychedelicious
parent 5745ce9c7d
commit 78ef946e01
31 changed files with 727 additions and 496 deletions

View File

@ -0,0 +1,4 @@
"""Initialization file for invokeai.backend.embeddings modules."""
# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]

View File

@ -0,0 +1,12 @@
"""Base class for LoRA and Textual Inversion models.
The EmbeddingModelRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.
The use of "Raw" here is a historical artifact, and is carried forward
to avoid confusion.
"""
class EmbeddingModelRaw:
"""Base class for LoRA and Textual Inversion models."""

View File

@ -11,6 +11,8 @@ from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from .embedding_base import EmbeddingModelRaw
class LoRALayerBase:
# rank: Optional[int]
@ -317,7 +319,7 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
@ -367,7 +369,7 @@ AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with an Info postfix and remove the Raw postfix here
class LoRAModelRaw: # (torch.nn.Module):
class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
@ -471,16 +473,16 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
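
A hedged usage sketch for the loader above; only file_path and base_model are visible in this hunk, so base_model is assumed to be a keyword parameter of from_checkpoint:

# Assumed call shape; only arguments visible in this hunk are used.
from pathlib import Path

lora = LoRAModelRaw.from_checkpoint(
    file_path=Path("loras/example-lora.safetensors"),  # hypothetical path
    base_model=BaseModelType.StableDiffusionXL,
)
print(f"loaded {len(lora.layers)} LoRA layers")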

View File

@ -4,22 +4,20 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
from safetensors.torch import load_file
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from typing_extensions import Self
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
"""
loras = [
@ -67,7 +65,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -76,8 +74,8 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -87,7 +85,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@ -97,7 +95,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@ -105,10 +103,10 @@ class ModelPatcher:
@contextmanager
def apply_lora(
cls,
model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
loras: List[Tuple[LoRAModelRaw, float]],
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
original_weights = {}
try:
with torch.no_grad():
@ -125,6 +123,7 @@ class ModelPatcher:
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
@ -170,8 +169,8 @@ class ModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModel]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
init_tokens_count = None
new_tokens_added = None
@ -201,7 +200,7 @@ class ModelPatcher:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor:
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
@ -229,6 +228,7 @@ class ModelPatcher:
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
assert isinstance(ti, TextualInversionModelRaw)
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
ti_tokens = []
@ -267,7 +267,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
) -> Generator[None, None, None]:
) -> None:
skipped_layers = []
try:
for _i in range(clip_skip):
@ -285,7 +285,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> Generator[None, None, None]:
) -> None:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
@ -301,94 +301,6 @@ class ModelPatcher:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here compensates for compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
@ -396,7 +308,7 @@ class ONNXModelPatcher:
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@ -406,7 +318,7 @@ class ONNXModelPatcher:
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@ -419,7 +331,7 @@ class ONNXModelPatcher:
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
@ -506,7 +418,7 @@ class ONNXModelPatcher:
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):

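A hedged sketch of how an invocation typically stacks these context managers; the call site is assumed, only the method names and signatures come from the hunks above:

# Assumed wiring -- illustrative only.
def run_with_patches(unet, tokenizer, text_encoder, loras, ti_list):
    """Apply LoRA weights and TI embeddings only for the duration of the block."""
    with ModelPatcher.apply_lora_unet(unet, loras), ModelPatcher.apply_ti(
        tokenizer, text_encoder, ti_list
    ) as (patched_tokenizer, ti_manager):
        # ... run the denoising loop with the patched UNet and tokenizer ...
        pass
    # On exit, the original weights and tokenizer state are restored.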
View File

@ -0,0 +1,100 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
from .embedding_base import EmbeddingModelRaw
class TextualInversionModelRaw(EmbeddingModelRaw):
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here compensates for compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
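
A small worked example of the pad-token expansion implemented above, with a toy pad_tokens mapping (the tokenizer wiring is omitted):

# Suppose a multi-vector TI trigger was assigned token 1000 plus two pad tokens.
pad_tokens = {1000: [1001, 1002]}
token_ids = [5, 1000, 9]  # no BOS/EOS, as the method requires

expanded = []
for token_id in token_ids:
    expanded.append(token_id)
    expanded.extend(pad_tokens.get(token_id, []))

assert expanded == [5, 1000, 1001, 1002, 9]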

View File

@ -241,10 +241,11 @@ class InstallHelper(object):
if match := re.match(f"^([^/]+/[^/]+?)(?::({variants}))?$", model_path_id_or_url):
repo_id = match.group(1)
repo_variant = ModelRepoVariant(match.group(2)) if match.group(2) else None
subfolder = Path(model_info.subfolder) if model_info.subfolder else None
return HFModelSource(
repo_id=repo_id,
access_token=HfFolder.get_token(),
subfolder=model_info.subfolder,
subfolder=subfolder,
variant=repo_variant,
)
if re.match(r"^(http|https):", model_path_id_or_url):

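For reference, a quick check of what the repo-id/variant regular expression above accepts; the variants alternation used here is an assumed subset (the real one is built from ModelRepoVariant):

import re

variants = "fp16|onnx"  # assumed subset for illustration
pattern = f"^([^/]+/[^/]+?)(?::({variants}))?$"

m = re.match(pattern, "stabilityai/sdxl-turbo:fp16")
assert m and m.group(1) == "stabilityai/sdxl-turbo" and m.group(2) == "fp16"

m = re.match(pattern, "stabilityai/sdxl-turbo")
assert m and m.group(2) is None  # no variant -> repo_variant stays None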
View File

@ -30,8 +30,11 @@ from typing_extensions import Annotated, Any, Dict
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from ..embeddings.embedding_base import EmbeddingModelRaw
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus, EmbeddingModelRaw]
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
@ -299,7 +302,7 @@ AnyModelConfig = Union[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown

View File

@ -18,8 +18,8 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import AnyModel, VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -19,7 +19,7 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import CacheStats, ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -71,7 +71,7 @@ class ModelLoader(ModelLoaderBase):
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
raise InvalidModelConfigException(f"Files for model '{model_config.name}' not found at {model_path}")
model_path = self._convert_if_needed(model_config, model_path, submodel_type)
locker = self._load_if_needed(model_config, model_path, submodel_type)

View File

@ -1,4 +1,6 @@
"""Init file for ModelCache."""
from .model_cache_base import ModelCacheBase, CacheStats # noqa F401
from .model_cache_default import ModelCache # noqa F401
_all__ = ["ModelCacheBase", "ModelCache"]
_all__ = ["ModelCacheBase", "ModelCache", "CacheStats"]

View File

@ -8,13 +8,13 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from logging import Logger
from typing import Generic, Optional, TypeVar
from typing import Dict, Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.config import AnyModel, SubModelType
class ModelLockerBase(ABC):
@ -65,6 +65,19 @@ class CacheRecord(Generic[T]):
return self._locks > 0
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCacheBase(ABC, Generic[T]):
"""Virtual base class for RAM model cache."""
@ -98,10 +111,22 @@ class ModelCacheBase(ABC, Generic[T]):
pass
@abstractmethod
def move_model_to_device(self, cache_entry: CacheRecord, device: torch.device) -> None:
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], device: torch.device) -> None:
"""Move model into the indicated device."""
pass
@property
@abstractmethod
def stats(self) -> CacheStats:
"""Return collected CacheStats object."""
pass
@stats.setter
@abstractmethod
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
pass
@property
@abstractmethod
def logger(self) -> Logger:

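A hedged sketch of how a caller opts in to statistics collection via the new property; the wiring to the invocation statistics service is assumed and not part of this diff:

# Assumed usage -- attach a CacheStats before a session, read it afterwards.
def report_cache_stats(cache: ModelCacheBase[AnyModel]) -> None:
    """Illustrative only."""
    if cache.stats is None:  # collection is off until a CacheStats object is attached
        cache.stats = CacheStats()
    # ... load some models ...
    s = cache.stats
    print(f"hits={s.hits} misses={s.misses} in_cache={s.in_cache} "
          f"high_watermark={s.high_watermark} bytes")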
View File

@ -24,19 +24,17 @@ import math
import sys
import time
from contextlib import suppress
from dataclasses import dataclass, field
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager import AnyModel, SubModelType
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
if choose_torch_device() == torch.device("mps"):
@ -56,20 +54,6 @@ GIG = 1073741824
MB = 2**20
@dataclass
class CacheStats(object):
"""Collect statistics on cache performance."""
hits: int = 0 # cache hits
misses: int = 0 # cache misses
high_watermark: int = 0 # amount of cache used
in_cache: int = 0 # number of models in cache
cleared: int = 0 # number of models cleared to make space
cache_size: int = 0 # total size of cache
# {submodel_key => size}
loaded_model_sizes: Dict[str, int] = field(default_factory=dict)
class ModelCache(ModelCacheBase[AnyModel]):
"""Implementation of ModelCacheBase."""
@ -110,7 +94,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
# used for stats collection
self.stats = CacheStats()
self._stats: Optional[CacheStats] = None
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
@ -140,6 +124,16 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""Return the cap on cache size."""
return self._max_cache_size
@property
def stats(self) -> Optional[CacheStats]:
"""Return collected CacheStats object."""
return self._stats
@stats.setter
def stats(self, stats: CacheStats) -> None:
"""Set the CacheStats object for collectin cache statistics."""
self._stats = stats
def cache_size(self) -> int:
"""Get the total size of the models currently cached."""
total = 0
@ -189,21 +183,24 @@ class ModelCache(ModelCacheBase[AnyModel]):
"""
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
self.stats.hits += 1
if self.stats:
self.stats.hits += 1
else:
self.stats.misses += 1
if self.stats:
self.stats.misses += 1
raise IndexError(f"The model with key {key} is not in the cache.")
cache_entry = self._cached_models[key]
# more stats
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
if self.stats:
stats_name = stats_name or key
self.stats.cache_size = int(self._max_cache_size * GIG)
self.stats.high_watermark = max(self.stats.high_watermark, self.cache_size())
self.stats.in_cache = len(self._cached_models)
self.stats.loaded_model_sizes[stats_name] = max(
self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size
)
# this moves the entry to the top (right end) of the stack
with suppress(Exception):

View File

@ -5,7 +5,7 @@
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
from invokeai.backend.embeddings.textual_inversion import TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,

View File

@ -4,3 +4,12 @@ Initialization file for the invokeai.backend.stable_diffusion package
from .diffusers_pipeline import PipelineIntermediateState, StableDiffusionGeneratorPipeline # noqa: F401
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .seamless import set_seamless # noqa: F401
__all__ = [
"PipelineIntermediateState",
"StableDiffusionGeneratorPipeline",
"InvokeAIDiffuserComponent",
"AttentionMapSaver",
"set_seamless",
]

View File

@ -0,0 +1,102 @@
from __future__ import annotations
from contextlib import contextmanager
from typing import List, Union
import torch.nn as nn
from diffusers.models import AutoencoderKL, UNet2DConditionModel
def _conv_forward_asymmetric(self, input, weight, bias):
"""
Patch for Conv2d._conv_forward that supports asymmetric padding
"""
working = nn.functional.pad(input, self.asymmetric_padding["x"], mode=self.asymmetric_padding_mode["x"])
working = nn.functional.pad(working, self.asymmetric_padding["y"], mode=self.asymmetric_padding_mode["y"])
return nn.functional.conv2d(
working,
weight,
bias,
self.stride,
nn.modules.utils._pair(0),
self.dilation,
self.groups,
)
@contextmanager
def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL], seamless_axes: List[str]):
try:
to_restore = []
for m_name, m in model.named_modules():
if isinstance(model, UNet2DConditionModel):
if ".attentions." in m_name:
continue
if ".resnets." in m_name:
if ".conv2" in m_name:
continue
if ".conv_shortcut" in m_name:
continue
"""
if isinstance(model, UNet2DConditionModel):
if False and ".upsamplers." in m_name:
continue
if False and ".downsamplers." in m_name:
continue
if True and ".resnets." in m_name:
if True and ".conv1" in m_name:
if False and "down_blocks" in m_name:
continue
if False and "mid_block" in m_name:
continue
if False and "up_blocks" in m_name:
continue
if True and ".conv2" in m_name:
continue
if True and ".conv_shortcut" in m_name:
continue
if True and ".attentions." in m_name:
continue
if False and m_name in ["conv_in", "conv_out"]:
continue
"""
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
m.asymmetric_padding_mode = {}
m.asymmetric_padding = {}
m.asymmetric_padding_mode["x"] = "circular" if ("x" in seamless_axes) else "constant"
m.asymmetric_padding["x"] = (
m._reversed_padding_repeated_twice[0],
m._reversed_padding_repeated_twice[1],
0,
0,
)
m.asymmetric_padding_mode["y"] = "circular" if ("y" in seamless_axes) else "constant"
m.asymmetric_padding["y"] = (
0,
0,
m._reversed_padding_repeated_twice[2],
m._reversed_padding_repeated_twice[3],
)
to_restore.append((m, m._conv_forward))
m._conv_forward = _conv_forward_asymmetric.__get__(m, nn.Conv2d)
yield
finally:
for module, orig_conv_forward in to_restore:
module._conv_forward = orig_conv_forward
if hasattr(module, "asymmetric_padding_mode"):
del module.asymmetric_padding_mode
if hasattr(module, "asymmetric_padding"):
del module.asymmetric_padding
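
A hedged usage sketch for set_seamless; the pipeline attribute names (.unet, .vae) are standard diffusers attributes, assumed here rather than taken from this diff:

from diffusers import StableDiffusionPipeline

from invokeai.backend.stable_diffusion import set_seamless  # re-exported in the package __init__ shown earlier

def generate_tileable(pipeline: StableDiffusionPipeline, prompt: str):
    """Illustrative only: patch both UNet and VAE so edges wrap along x and y."""
    with set_seamless(pipeline.unet, ["x", "y"]), set_seamless(pipeline.vae, ["x", "y"]):
        # Conv2d layers now pad circularly on the requested axes for this call only.
        return pipeline(prompt).images[0]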

View File

@ -0,0 +1,28 @@
"""Context class to silence transformers and diffusers warnings."""
import warnings
from typing import Any
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
"""Use in context to temporarily turn off warnings from transformers & diffusers modules.
with SilenceWarnings():
# do something
"""
def __init__(self) -> None:
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self) -> None:
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter("ignore")
def __exit__(self, *args: Any) -> None:
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter("default")