BREAKING CHANGES: invocations now require a model key, not base/type/name (see the sketch below)

- Implement new model loader and modify invocations and embeddings

- Finish implementing loaders for all models currently supported by
  InvokeAI.

- Move lora, textual_inversion, and model patching support into
  backend/embeddings.

- Restore support for model cache statistics collection (a little ugly,
  needs work).

- Fix up invocations that load and patch models.

- Move seamless and silencewarnings utils into a better location.
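
To make the breaking change above concrete, a minimal sketch using placeholder field names and a placeholder key value (the exact schema is defined by the new model loader and record store, which this diff does not reproduce):

# Hypothetical model fields illustrating the change; all values are placeholders.
old_model_field = {
    "model_name": "stable-diffusion-v1-5",  # name
    "base_model": "sd-1",                   # base
    "model_type": "main",                   # type
}
new_model_field = {
    "key": "<unique-model-record-key>",  # single opaque key assigned when the model is registered
}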
Author: Lincoln Stein
Date: 2024-02-05 22:56:32 -05:00
Committed by: psychedelicious
Parent: fbded1c0f2
Commit: dfcf38be91
31 changed files with 727 additions and 496 deletions

invokeai/backend/embeddings/__init__.py

@@ -0,0 +1,4 @@
"""Initialization file for invokeai.backend.embeddings modules."""
# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]

invokeai/backend/embeddings/embedding_base.py

@@ -0,0 +1,12 @@
"""Base class for LoRA and Textual Inversion models.
The EmbeddingModelRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.
The use of "Raw" here is a historical artifact, carried forward in
order to avoid confusion.
"""
class EmbeddingModelRaw:
"""Base class for LoRA and Textual Inversion models."""

invokeai/backend/embeddings/lora.py

@@ -11,6 +11,8 @@ from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
from .embedding_base import EmbeddingModelRaw
class LoRALayerBase:
# rank: Optional[int]
@@ -317,7 +319,7 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
@@ -367,7 +369,7 @@ AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with an Info postfix and remove the Raw postfix here
class LoRAModelRaw: # (torch.nn.Module):
class LoRAModelRaw(EmbeddingModelRaw): # (torch.nn.Module):
_name: str
layers: Dict[str, AnyLoRALayer]
@@ -471,16 +473,16 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
name=file_path.stem, # TODO:
name=file_path.stem,
layers={},
)
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
sd = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
sd = torch.load(file_path, map_location="cpu")
state_dict = cls._group_state(state_dict)
state_dict = cls._group_state(sd)
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
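
A hedged usage sketch of the loader changed in this hunk, assuming the module now lives at invokeai/backend/embeddings/lora.py and that from_checkpoint accepts the base_model argument referenced above; the checkpoint path is a placeholder:

from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType

# Load a LoRA from disk; for SDXL checkpoints the keys are converted to diffusers format.
lora = LoRAModelRaw.from_checkpoint(
    "/path/to/some_lora.safetensors",  # placeholder path
    base_model=BaseModelType.StableDiffusionXL,
)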

invokeai/backend/embeddings/model_patcher.py

@@ -4,22 +4,20 @@ from __future__ import annotations
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
from safetensors.torch import load_file
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from typing_extensions import Self
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw
"""
loras = [
@@ -67,7 +65,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@@ -76,8 +74,8 @@ class ModelPatcher:
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
loras: Iterator[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@@ -87,7 +85,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@@ -97,7 +95,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@@ -105,10 +103,10 @@ class ModelPatcher:
@contextmanager
def apply_lora(
cls,
model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
loras: List[Tuple[LoRAModelRaw, float]],
model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
original_weights = {}
try:
with torch.no_grad():
@@ -125,6 +123,7 @@ class ModelPatcher:
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
assert isinstance(model, torch.nn.Module)
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
@@ -170,8 +169,8 @@ class ModelPatcher:
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModel]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
ti_list: List[Tuple[str, TextualInversionModelRaw]],
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
init_tokens_count = None
new_tokens_added = None
@@ -201,7 +200,7 @@ class ModelPatcher:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor:
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
@@ -229,6 +228,7 @@ class ModelPatcher:
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
assert isinstance(ti, TextualInversionModelRaw)
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
ti_tokens = []
@@ -267,7 +267,7 @@ class ModelPatcher:
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
) -> Generator[None, None, None]:
) -> None:
skipped_layers = []
try:
for _i in range(clip_skip):
@@ -285,7 +285,7 @@ class ModelPatcher:
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> Generator[None, None, None]:
) -> None:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
@@ -301,94 +301,6 @@ class ModelPatcher:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here compensates for compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
@@ -396,7 +308,7 @@ class ONNXModelPatcher:
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@@ -406,7 +318,7 @@ class ONNXModelPatcher:
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@@ -419,7 +331,7 @@ class ONNXModelPatcher:
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
) -> None:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
@@ -506,7 +418,7 @@ class ONNXModelPatcher:
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):

invokeai/backend/embeddings/textual_inversion.py

@@ -0,0 +1,100 @@
"""Textual Inversion wrapper class."""
from pathlib import Path
from typing import Dict, List, Optional, Union
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self
from .embedding_base import EmbeddingModelRaw
class TextualInversionModelRaw(EmbeddingModelRaw):
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here compensates for compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
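
To make the pad-token expansion above concrete, a small hedged example, assuming the module now lives at invokeai/backend/embeddings/textual_inversion.py and using hypothetical token ids:

from transformers import CLIPTokenizer

from invokeai.backend.embeddings.textual_inversion import TextualInversionManager

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
manager = TextualInversionManager(tokenizer)

# Suppose a two-vector embedding was registered under trigger token 49408 and its extra
# vectors were mapped to pad tokens 49409 and 49410 (hypothetical ids).
manager.pad_tokens[49408] = [49409, 49410]

# Each occurrence of the trigger is followed by its pad tokens; other ids pass through,
# and the result is clipped to the tokenizer's max input size minus the bos/eos slots.
assert manager.expand_textual_inversion_token_ids_if_necessary([320, 49408, 550]) == [
    320, 49408, 49409, 49410, 550,
]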