BREAKING CHANGES: invocations now require a model key, not base/type/name

- Implement the new model loader and modify invocations and embeddings accordingly.
- Finish implementing loaders for all models currently supported by InvokeAI.
- Move lora, textual_inversion, and model patching support into backend/embeddings.
- Restore support for model cache statistics collection (a little ugly, needs work).
- Fix up invocations that load and patch models.
- Move the seamless and silencewarnings utils into a better location.
Commit dfcf38be91, committed by psychedelicious (parent fbded1c0f2).
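The headline change is that an invocation no longer identifies a model by its base, type, and name, but by the single key the model manager assigns to its record. A minimal before/after sketch (the field names and key value below are illustrative only, not the actual invocation schema):

# before this commit: a model reference carried base/type/name
model_ref = {"base_model": "sd-1", "model_type": "main", "model_name": "stable-diffusion-v1-5"}

# after this commit: a model reference carries only the model manager's record key
model_ref = {"key": "f7a4c6e2"}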
invokeai/backend/embeddings/__init__.py (new file, 4 lines)

@@ -0,0 +1,4 @@
"""Initialization file for invokeai.backend.embeddings modules."""

# from .model_patcher import ModelPatcher
# __all__ = ["ModelPatcher"]
invokeai/backend/embeddings/embedding_base.py (new file, 12 lines)

@@ -0,0 +1,12 @@
"""Base class for LoRA and Textual Inversion models.

The EmbeddingModelRaw class is the base class of LoRAModelRaw and TextualInversionModelRaw,
and is used for type checking of calls to the model patcher.

The use of "Raw" here is a historical artifact, carried forward in order to avoid confusion.
"""


class EmbeddingModelRaw:
    """Base class for LoRA and Textual Inversion models."""
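A minimal sketch of how this base class supports type checking in the patcher (the helper function below is illustrative, not part of the commit):

from invokeai.backend.embeddings.embedding_base import EmbeddingModelRaw

def check_patchable(model_raw: object) -> EmbeddingModelRaw:
    # LoRAModelRaw and TextualInversionModelRaw both derive from EmbeddingModelRaw,
    # so a single isinstance check covers every embedding type the patcher accepts.
    if not isinstance(model_raw, EmbeddingModelRaw):
        raise TypeError(f"expected an EmbeddingModelRaw subclass, got {type(model_raw).__name__}")
    return model_raw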
invokeai/backend/embeddings/lora.py

@@ -11,6 +11,8 @@ from typing_extensions import Self

from invokeai.backend.model_manager import BaseModelType

from .embedding_base import EmbeddingModelRaw


class LoRALayerBase:
    # rank: Optional[int]

@@ -317,7 +319,7 @@ class FullLayer(LoRALayerBase):
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
    ) -> None:
        super().to(device=device, dtype=dtype)

        self.weight = self.weight.to(device=device, dtype=dtype)

@@ -367,7 +369,7 @@ AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]


# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw:  # (torch.nn.Module):
class LoRAModelRaw(EmbeddingModelRaw):  # (torch.nn.Module):
    _name: str
    layers: Dict[str, AnyLoRALayer]

@@ -471,16 +473,16 @@ class LoRAModelRaw:  # (torch.nn.Module):
        file_path = Path(file_path)

        model = cls(
            name=file_path.stem,  # TODO:
            name=file_path.stem,
            layers={},
        )

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
            sd = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            state_dict = torch.load(file_path, map_location="cpu")
            sd = torch.load(file_path, map_location="cpu")

        state_dict = cls._group_state(state_dict)
        state_dict = cls._group_state(sd)

        if base_model == BaseModelType.StableDiffusionXL:
            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
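LoRAModelRaw.from_checkpoint reads either a .safetensors or a torch pickle file, groups the flat state dict into layers, and converts SDXL keys to the diffusers naming scheme. A hedged loading sketch; only file_path and base_model appear in the hunk above, so treating them as keyword parameters is an assumption:

from pathlib import Path

from invokeai.backend.embeddings.lora import LoRAModelRaw  # module path per this commit
from invokeai.backend.model_manager import BaseModelType

# a sketch, not from the commit: load an SDXL LoRA checkpoint from disk
lora = LoRAModelRaw.from_checkpoint(
    file_path=Path("/path/to/my_lora.safetensors"),
    base_model=BaseModelType.StableDiffusionXL,
)
print(lora.layers.keys())  # grouped layer dict, per `layers: Dict[str, AnyLoRALayer]` above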
invokeai/backend/embeddings/model_patcher.py

@@ -4,22 +4,20 @@ from __future__ import annotations

import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple

import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
from safetensors.torch import load_file
from diffusers import OnnxRuntimeModel, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
from typing_extensions import Self

from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel

from .lora import LoRAModelRaw
from .textual_inversion import TextualInversionManager, TextualInversionModelRaw

"""
loras = [
@@ -67,7 +65,7 @@ class ModelPatcher:
        cls,
        unet: UNet2DConditionModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ) -> Generator[None, None, None]:
    ) -> None:
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

@@ -76,8 +74,8 @@ class ModelPatcher:
    def apply_lora_text_encoder(
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ):
        loras: Iterator[Tuple[LoRAModelRaw, float]],
    ) -> None:
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

@@ -87,7 +85,7 @@ class ModelPatcher:
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ):
    ) -> None:
        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
            yield

@@ -97,7 +95,7 @@ class ModelPatcher:
        cls,
        text_encoder: CLIPTextModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ):
    ) -> None:
        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
            yield

@@ -105,10 +103,10 @@ class ModelPatcher:
    @contextmanager
    def apply_lora(
        cls,
        model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
        loras: List[Tuple[LoRAModelRaw, float]],
        model: AnyModel,
        loras: Iterator[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ) -> Generator[None, None, None]:
    ) -> None:
        original_weights = {}
        try:
            with torch.no_grad():
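These patch methods are classmethod context managers. A minimal usage sketch, assuming `unet` is a loaded UNet2DConditionModel and `lora` a LoRAModelRaw; the method name apply_lora_unet is not visible in the hunk and is inferred from the "lora_unet_" prefix it applies, and the 0.75 weight is arbitrary:

# a sketch, not from the commit: patch LoRA weights in for the duration of a run
with ModelPatcher.apply_lora_unet(unet, [(lora, 0.75)]):
    ...  # run denoising here; the saved original_weights are restored when the context exits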
@@ -125,6 +123,7 @@ class ModelPatcher:
                    # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
                    #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
                    #    weights to have valid keys.
                    assert isinstance(model, torch.nn.Module)
                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)

                    # All of the LoRA weight calculations will be done on the same device as the module weight.

@@ -170,8 +169,8 @@ class ModelPatcher:
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: CLIPTextModel,
        ti_list: List[Tuple[str, TextualInversionModel]],
    ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
        ti_list: List[Tuple[str, TextualInversionModelRaw]],
    ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
        init_tokens_count = None
        new_tokens_added = None
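The textual-inversion patcher yields the patched tokenizer together with a TextualInversionManager. A usage sketch; the method name apply_ti is an assumption (the hunk shows only its parameters), and the trigger name and prompt are invented for illustration:

# a sketch, not from the commit: patch textual inversions in, then expand trigger tokens
ti_list = [("my-embedding", ti_model)]  # (trigger name, TextualInversionModelRaw) pairs
with ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (patched_tokenizer, ti_manager):
    # the manager rejects ids that start/end with bos/eos, so encode without special tokens
    ids = patched_tokenizer.encode("a portrait, my-embedding style", add_special_tokens=False)
    ids = ti_manager.expand_textual_inversion_token_ids_if_necessary(ids)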
@@ -201,7 +200,7 @@ class ModelPatcher:
                trigger += f"-!pad-{i}"
            return f"<{trigger}>"

        def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
        def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModelRaw) -> torch.Tensor:
            # for SDXL models, select the embedding that matches the text encoder's dimensions
            if ti.embedding_2 is not None:
                return (

@@ -229,6 +228,7 @@ class ModelPatcher:
            model_embeddings = text_encoder.get_input_embeddings()

            for ti_name, ti in ti_list:
                assert isinstance(ti, TextualInversionModelRaw)
                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)

                ti_tokens = []

@@ -267,7 +267,7 @@ class ModelPatcher:
        cls,
        text_encoder: CLIPTextModel,
        clip_skip: int,
    ) -> Generator[None, None, None]:
    ) -> None:
        skipped_layers = []
        try:
            for _i in range(clip_skip):

@@ -285,7 +285,7 @@ class ModelPatcher:
        cls,
        unet: UNet2DConditionModel,
        freeu_config: Optional[FreeUConfig] = None,
    ) -> Generator[None, None, None]:
    ) -> None:
        did_apply_freeu = False
        try:
            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
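The same context-manager pattern covers CLIP skip and FreeU. A sketch assuming a loaded text_encoder and unet; the method names apply_clip_skip/apply_freeu are not shown in the hunks, and the FreeUConfig field names and values are assumed from diffusers' enable_freeu arguments:

# a sketch, not from the commit
with ModelPatcher.apply_clip_skip(text_encoder, clip_skip=2):
    ...  # prompts encoded here skip the last two CLIP layers

with ModelPatcher.apply_freeu(unet, freeu_config=FreeUConfig(s1=0.9, s2=0.2, b1=1.2, b2=1.4)):
    ...  # denoising runs with FreeU scaling; disable_freeu() is called on exit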
@@ -301,94 +301,6 @@ class ModelPatcher:
            unet.disable_freeu()


class TextualInversionModel:
    embedding: torch.Tensor  # [n, 768]|[n, 1280]
    embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        result = cls()  # TODO:

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            state_dict = torch.load(file_path, map_location="cpu")

        # both v1 and v2 format embeddings
        # difference mostly in metadata
        if "string_to_param" in state_dict:
            if len(state_dict["string_to_param"]) > 1:
                print(
                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
                    " token will be used.",
                )

            result.embedding = next(iter(state_dict["string_to_param"].values()))

        # v3 (easynegative)
        elif "emb_params" in state_dict:
            result.embedding = state_dict["emb_params"]

        # v5(sdxl safetensors file)
        elif "clip_g" in state_dict and "clip_l" in state_dict:
            result.embedding = state_dict["clip_g"]
            result.embedding_2 = state_dict["clip_l"]

        # v4(diffusers bin files)
        else:
            result.embedding = next(iter(state_dict.values()))

        if len(result.embedding.shape) == 1:
            result.embedding = result.embedding.unsqueeze(0)

        if not isinstance(result.embedding, torch.Tensor):
            raise ValueError(f"Invalid embeddings file: {file_path.name}")

        return result


# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
    pad_tokens: Dict[int, List[int]]
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        self.pad_tokens = {}
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
        if len(self.pad_tokens) == 0:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        # Do not exceed the max model input size
        # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
        # which first removes and then adds back the start and end tokens.
        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
        if len(new_token_ids) > max_length:
            new_token_ids = new_token_ids[0:max_length]

        return new_token_ids


class ONNXModelPatcher:
    @classmethod
    @contextmanager
@@ -396,7 +308,7 @@ class ONNXModelPatcher:
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ) -> Generator[None, None, None]:
    ) -> None:
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

@@ -406,7 +318,7 @@ class ONNXModelPatcher:
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
    ) -> Generator[None, None, None]:
    ) -> None:
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield

@@ -419,7 +331,7 @@ class ONNXModelPatcher:
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModelRaw, float]],
        prefix: str,
    ) -> Generator[None, None, None]:
    ) -> None:
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(model, IAIOnnxRuntimeModel):

@@ -506,7 +418,7 @@ class ONNXModelPatcher:
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Tuple[str, Any]],
    ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
    ) -> Iterator[Tuple[CLIPTokenizer, TextualInversionManager]]:
        from .models.base import IAIOnnxRuntimeModel

        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
invokeai/backend/embeddings/textual_inversion.py (new file, 100 lines)

@@ -0,0 +1,100 @@
"""Textual Inversion wrapper class."""

from pathlib import Path
from typing import Dict, List, Optional, Union

import torch
from compel.embeddings_provider import BaseTextualInversionManager
from safetensors.torch import load_file
from transformers import CLIPTokenizer
from typing_extensions import Self

from .embedding_base import EmbeddingModelRaw


class TextualInversionModelRaw(EmbeddingModelRaw):
    embedding: torch.Tensor  # [n, 768]|[n, 1280]
    embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Self:
        if not isinstance(file_path, Path):
            file_path = Path(file_path)

        result = cls()  # TODO:

        if file_path.suffix == ".safetensors":
            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
        else:
            state_dict = torch.load(file_path, map_location="cpu")

        # both v1 and v2 format embeddings
        # difference mostly in metadata
        if "string_to_param" in state_dict:
            if len(state_dict["string_to_param"]) > 1:
                print(
                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
                    " token will be used.",
                )

            result.embedding = next(iter(state_dict["string_to_param"].values()))

        # v3 (easynegative)
        elif "emb_params" in state_dict:
            result.embedding = state_dict["emb_params"]

        # v5 (sdxl safetensors file)
        elif "clip_g" in state_dict and "clip_l" in state_dict:
            result.embedding = state_dict["clip_g"]
            result.embedding_2 = state_dict["clip_l"]

        # v4 (diffusers bin files)
        else:
            result.embedding = next(iter(state_dict.values()))

        if len(result.embedding.shape) == 1:
            result.embedding = result.embedding.unsqueeze(0)

        if not isinstance(result.embedding, torch.Tensor):
            raise ValueError(f"Invalid embeddings file: {file_path.name}")

        return result


# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
    pad_tokens: Dict[int, List[int]]
    tokenizer: CLIPTokenizer

    def __init__(self, tokenizer: CLIPTokenizer):
        self.pad_tokens = {}
        self.tokenizer = tokenizer

    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
        if len(self.pad_tokens) == 0:
            return token_ids

        if token_ids[0] == self.tokenizer.bos_token_id:
            raise ValueError("token_ids must not start with bos_token_id")
        if token_ids[-1] == self.tokenizer.eos_token_id:
            raise ValueError("token_ids must not end with eos_token_id")

        new_token_ids = []
        for token_id in token_ids:
            new_token_ids.append(token_id)
            if token_id in self.pad_tokens:
                new_token_ids.extend(self.pad_tokens[token_id])

        # Do not exceed the max model input size.
        # The -2 here compensates for compel.embeddings_provider.get_token_ids(),
        # which first removes and then adds back the start and end tokens.
        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
        if len(new_token_ids) > max_length:
            new_token_ids = new_token_ids[0:max_length]

        return new_token_ids
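A small sketch of how the pad-token expansion behaves; the stub tokenizer below stands in for a real CLIPTokenizer and is not part of the commit:

from invokeai.backend.embeddings.textual_inversion import TextualInversionManager

class _StubTokenizer:
    # Mimics only the CLIPTokenizer attributes that TextualInversionManager reads.
    bos_token_id = 49406
    eos_token_id = 49407
    max_model_input_sizes = {"openai/clip-vit-base-patch32": 77}

manager = TextualInversionManager(_StubTokenizer())  # type: ignore[arg-type]
manager.pad_tokens[1000] = [1001, 1002]  # token 1000 expands into three ids

print(manager.expand_textual_inversion_token_ids_if_necessary([42, 1000, 7]))
# -> [42, 1000, 1001, 1002, 7]; results longer than 75 ids (77 - 2) would be truncated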