added textual inversion and lora loaders

This commit is contained in:
Lincoln Stein 2024-02-04 23:18:00 -05:00 committed by psychedelicious
parent 34d5cad4c9
commit ad2926a24c
16 changed files with 701 additions and 38 deletions

View File

@ -178,6 +178,11 @@ class ModelInstallService(ModelInstallServiceBase):
)
def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob: # noqa D102
similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
if similar_jobs:
self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
return similar_jobs[0]
if isinstance(source, LocalModelSource):
install_job = self._import_local_model(source, config)
self._install_queue.put(install_job) # synchronously install

View File

@ -1,13 +1,17 @@
# Copyright (c) 2024 The InvokeAI Development team
"""LoRA model support."""
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import load_file
from pathlib import Path
from typing import Dict, Optional, Union, List, Tuple
from typing_extensions import Self
from invokeai.backend.model_manager import BaseModelType
class LoRALayerBase:
# rank: Optional[int]
# alpha: Optional[float]
@ -41,7 +45,7 @@ class LoRALayerBase:
self.rank = None # set in layer implementation
self.layer_key = layer_key
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
def calc_size(self) -> int:
@ -82,7 +86,7 @@ class LoRALayer(LoRALayerBase):
self.rank = self.down.shape[0]
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
if self.mid is not None:
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@ -121,11 +125,7 @@ class LoHALayer(LoRALayerBase):
# t1: Optional[torch.Tensor] = None
# t2: Optional[torch.Tensor] = None
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor]
):
def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
super().__init__(layer_key, values)
self.w1_a = values["hada_w1_a"]
@ -145,7 +145,7 @@ class LoHALayer(LoRALayerBase):
self.rank = self.w1_b.shape[0]
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
if self.t1 is None:
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@ -227,7 +227,7 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
w1: Optional[torch.Tensor] = self.w1
if w1 is None:
assert self.w1_a is not None
@ -305,7 +305,7 @@ class FullLayer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
return self.weight
def calc_size(self) -> int:
@ -330,7 +330,7 @@ class IA3Layer(LoRALayerBase):
def __init__(
self,
layer_key: str,
values: Dict[str, torch.Tensor],
values: Dict[str, torch.Tensor],
):
super().__init__(layer_key, values)
@ -339,10 +339,11 @@ class IA3Layer(LoRALayerBase):
self.rank = None # unscaled
def get_weight(self, orig_weight: torch.Tensor):
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
weight = self.weight
if not self.on_input:
weight = weight.reshape(-1, 1)
assert orig_weight is not None
return orig_weight * weight
def calc_size(self) -> int:
@ -361,8 +362,10 @@ class IA3Layer(LoRALayerBase):
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
class LoRAModelRaw: # (torch.nn.Module):
_name: str
@ -530,7 +533,7 @@ class LoRAModelRaw: # (torch.nn.Module):
# code from
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]:
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
unet_conversion_map_layer = []

View File

@ -0,0 +1,586 @@
# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team
"""These classes implement model patching with LoRAs and Textual Inversions."""
from __future__ import annotations
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
import numpy as np
import torch
from compel.embeddings_provider import BaseTextualInversionManager
from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
from safetensors.torch import load_file
from transformers import CLIPTextModel, CLIPTokenizer
from typing_extensions import Self
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from .lora import LoRAModelRaw
"""
loras = [
(lora_model1, 0.7),
(lora_model2, 0.4),
]
with LoRAHelper.apply_lora_unet(unet, loras):
# unet with applied loras
# unmodified unet
"""
# TODO: rename smth like ModelPatcher and add TI method?
class ModelPatcher:
@staticmethod
def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
assert "." not in lora_key
if not lora_key.startswith(prefix):
raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
module = model
module_key = ""
key_parts = lora_key[len(prefix) :].split("_")
submodule_name = key_parts.pop(0)
while len(key_parts) > 0:
try:
module = module.get_submodule(submodule_name)
module_key += "." + submodule_name
submodule_name = key_parts.pop(0)
except Exception:
submodule_name += "_" + key_parts.pop(0)
module = module.get_submodule(submodule_name)
module_key = (module_key + "." + submodule_name).lstrip(".")
return (module_key, module)
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: UNet2DConditionModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
):
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield
@classmethod
@contextmanager
def apply_lora(
cls,
model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
original_weights = {}
try:
with torch.no_grad():
for lora, lora_weight in loras:
# assert lora.device.type == "cpu"
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu"))
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
assert hasattr(layer_weight, "reshape")
layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
finally:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad():
for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight)
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
ti_list: List[Tuple[str, TextualInversionModel]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
init_tokens_count = None
new_tokens_added = None
# TODO: This is required since Transformers 4.32 see
# https://github.com/huggingface/transformers/pull/25088
# More information by NVIDIA:
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
# This value might need to be changed in the future and take the GPUs model into account as there seem
# to be ideal values for different GPUS. This value is temporary!
# For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
pad_to_multiple_of = 8
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
def _get_trigger(ti_name: str, index: int) -> str:
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
ti.embedding_2
if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
else ti.embedding
)
else:
return ti.embedding
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
# Modify text_encoder.
# resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
# this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
# time.
with skip_torch_weight_init():
text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
model_embeddings = text_encoder.get_input_embeddings()
for ti_name, ti in ti_list:
ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i]
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding.to(
device=text_encoder.device, dtype=text_encoder.dtype
)
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
yield ti_tokenizer, ti_manager
finally:
if init_tokens_count and new_tokens_added:
text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
@classmethod
@contextmanager
def apply_clip_skip(
cls,
text_encoder: CLIPTextModel,
clip_skip: int,
) -> Generator[None, None, None]:
skipped_layers = []
try:
for _i in range(clip_skip):
skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
yield
finally:
while len(skipped_layers) > 0:
text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
@classmethod
@contextmanager
def apply_freeu(
cls,
unet: UNet2DConditionModel,
freeu_config: Optional[FreeUConfig] = None,
) -> Generator[None, None, None]:
did_apply_freeu = False
try:
assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute?
if freeu_config is not None:
unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
did_apply_freeu = True
yield
finally:
assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute?
if did_apply_freeu:
unet.disable_freeu()
class TextualInversionModel:
embedding: torch.Tensor # [n, 768]|[n, 1280]
embedding_2: Optional[torch.Tensor] = None # [n, 768]|[n, 1280] - for SDXL models
@classmethod
def from_checkpoint(
cls,
file_path: Union[str, Path],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Self:
if not isinstance(file_path, Path):
file_path = Path(file_path)
result = cls() # TODO:
if file_path.suffix == ".safetensors":
state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
else:
state_dict = torch.load(file_path, map_location="cpu")
# both v1 and v2 format embeddings
# difference mostly in metadata
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
" token will be used.",
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
# v3 (easynegative)
elif "emb_params" in state_dict:
result.embedding = state_dict["emb_params"]
# v5(sdxl safetensors file)
elif "clip_g" in state_dict and "clip_l" in state_dict:
result.embedding = state_dict["clip_g"]
result.embedding_2 = state_dict["clip_l"]
# v4(diffusers bin files)
else:
result.embedding = next(iter(state_dict.values()))
if len(result.embedding.shape) == 1:
result.embedding = result.embedding.unsqueeze(0)
if not isinstance(result.embedding, torch.Tensor):
raise ValueError(f"Invalid embeddings file: {file_path.name}")
return result
# no type hints for BaseTextualInversionManager?
class TextualInversionManager(BaseTextualInversionManager): # type: ignore
pad_tokens: Dict[int, List[int]]
tokenizer: CLIPTokenizer
def __init__(self, tokenizer: CLIPTokenizer):
self.pad_tokens = {}
self.tokenizer = tokenizer
def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
if len(self.pad_tokens) == 0:
return token_ids
if token_ids[0] == self.tokenizer.bos_token_id:
raise ValueError("token_ids must not start with bos_token_id")
if token_ids[-1] == self.tokenizer.eos_token_id:
raise ValueError("token_ids must not end with eos_token_id")
new_token_ids = []
for token_id in token_ids:
new_token_ids.append(token_id)
if token_id in self.pad_tokens:
new_token_ids.extend(self.pad_tokens[token_id])
# Do not exceed the max model input size
# The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
# which first removes and then adds back the start and end tokens.
max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
if len(new_token_ids) > max_length:
new_token_ids = new_token_ids[0:max_length]
return new_token_ids
class ONNXModelPatcher:
@classmethod
@contextmanager
def apply_lora_unet(
cls,
unet: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
with cls.apply_lora(unet, loras, "lora_unet_"):
yield
@classmethod
@contextmanager
def apply_lora_text_encoder(
cls,
text_encoder: OnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras, "lora_te_"):
yield
# based on
# https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
@classmethod
@contextmanager
def apply_lora(
cls,
model: IAIOnnxRuntimeModel,
loras: List[Tuple[LoRAModelRaw, float]],
prefix: str,
) -> Generator[None, None, None]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(model, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_weights = {}
try:
blended_loras: Dict[str, torch.Tensor] = {}
for lora, lora_weight in loras:
for layer_key, layer in lora.layers.items():
if not layer_key.startswith(prefix):
continue
layer.to(dtype=torch.float32)
layer_key = layer_key.replace(prefix, "")
# TODO: rewrite to pass original tensor weight(required by ia3)
layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
if layer_key in blended_loras:
blended_loras[layer_key] += layer_weight
else:
blended_loras[layer_key] = layer_weight
node_names = {}
for node in model.nodes.values():
node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
for layer_key, lora_weight in blended_loras.items():
conv_key = layer_key + "_Conv"
gemm_key = layer_key + "_Gemm"
matmul_key = layer_key + "_MatMul"
if conv_key in node_names or gemm_key in node_names:
if conv_key in node_names:
conv_node = model.nodes[node_names[conv_key]]
else:
conv_node = model.nodes[node_names[gemm_key]]
weight_name = [n for n in conv_node.input if ".weight" in n][0]
orig_weight = model.tensors[weight_name]
if orig_weight.shape[-2:] == (1, 1):
if lora_weight.shape[-2:] == (1, 1):
new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
else:
new_weight = orig_weight.squeeze((3, 2)) + lora_weight
new_weight = np.expand_dims(new_weight, (2, 3))
else:
if orig_weight.shape != lora_weight.shape:
new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
else:
new_weight = orig_weight + lora_weight
orig_weights[weight_name] = orig_weight
model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
elif matmul_key in node_names:
weight_node = model.nodes[node_names[matmul_key]]
matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
orig_weight = model.tensors[matmul_name]
new_weight = orig_weight + lora_weight.transpose()
orig_weights[matmul_name] = orig_weight
model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
else:
# warn? err?
pass
yield
finally:
# restore original weights
for name, orig_weight in orig_weights.items():
model.tensors[name] = orig_weight
@classmethod
@contextmanager
def apply_ti(
cls,
tokenizer: CLIPTokenizer,
text_encoder: IAIOnnxRuntimeModel,
ti_list: List[Tuple[str, Any]],
) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
from .models.base import IAIOnnxRuntimeModel
if not isinstance(text_encoder, IAIOnnxRuntimeModel):
raise Exception("Only IAIOnnxRuntimeModel models supported")
orig_embeddings = None
try:
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name: str, index: int) -> str:
trigger = ti_name
if index > 0:
trigger += f"-!pad-{i}"
return f"<{trigger}>"
# modify text_encoder
orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
# modify tokenizer
new_tokens_added = 0
for ti_name, ti in ti_list:
if ti.embedding_2 is not None:
ti_embedding = (
ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
)
else:
ti_embedding = ti.embedding
for i in range(ti_embedding.shape[0]):
new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
embeddings = np.concatenate(
(np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
axis=0,
)
for ti_name, _ in ti_list:
ti_tokens = []
for i in range(ti_embedding.shape[0]):
embedding = ti_embedding[i].detach().numpy()
trigger = _get_trigger(ti_name, i)
token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
if token_id == ti_tokenizer.unk_token_id:
raise RuntimeError(f"Unable to find token id for token '{trigger}'")
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding
ti_tokens.append(token_id)
if len(ti_tokens) > 1:
ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
orig_embeddings.dtype
)
yield ti_tokenizer, ti_manager
finally:
# restore
if orig_embeddings is not None:
text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings

View File

@ -102,7 +102,7 @@ class ModelPatcher:
def apply_lora(
cls,
model: torch.nn.Module,
loras: List[Tuple[LoRAModel, float]],
loras: List[Tuple[LoRAModel, float]], # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw
prefix: str,
):
original_weights = {}
@ -194,6 +194,8 @@ class ModelPatcher:
return f"<{trigger}>"
def _get_ti_embedding(model_embeddings, ti):
print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}")
print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}")
# for SDXL models, select the embedding that matches the text encoder's dimensions
if ti.embedding_2 is not None:
return (
@ -202,6 +204,7 @@ class ModelPatcher:
else ti.embedding
)
else:
print(f"DEBUG: ti.embedding={type(ti.embedding)}")
return ti.embedding
# modify tokenizer

View File

@ -28,9 +28,11 @@ from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""

View File

@ -10,11 +10,17 @@ from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
InvalidModelConfigException,
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -160,4 +166,3 @@ class ModelLoader(ModelLoaderBase):
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError

View File

@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return "\n"+msg if len(msg)>0 else msg
return "\n" + msg if len(msg) > 0 else msg

View File

@ -1,5 +1,3 @@
"""Init file for RamCache."""
from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
_all__ = ["ModelCacheBase", "ModelCache"]

View File

@ -14,8 +14,10 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
@ -37,7 +39,7 @@ class ControlnetLoader(GenericDiffusersLoader):
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
assert hasattr(config, 'config')
assert hasattr(config, "config")
config_file = config.config
if weights_path.suffix == ".safetensors":

View File

@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):

View File

@ -1,11 +1,11 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for IP Adapter model loading in InvokeAI."""
import torch
from pathlib import Path
from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import (
AnyModel,
@ -18,6 +18,7 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""
@ -36,4 +37,3 @@ class IPAdapterInvokeAILoader(ModelLoader):
dtype=self._torch_dtype,
)
return model

View File

@ -2,13 +2,12 @@
"""Class for LoRA model loading in InvokeAI."""
from logging import Logger
from pathlib import Path
from typing import Optional, Tuple
from logging import Logger
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.embeddings.lora import LoRAModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
@ -18,9 +17,11 @@ from invokeai.backend.model_manager import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.lora import LoRAModelRaw
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
@ -47,6 +48,7 @@ class LoraLoader(ModelLoader):
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a LoRA model.")
assert self._model_base is not None
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
@ -56,9 +58,11 @@ class LoraLoader(ModelLoader):
# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
self._model_base = config.base # cheating a little - setting this variable for later call to _load_model()
self._model_base = (
config.base
) # cheating a little - we remember this variable for using in the subsequent call to _load_model()
model_base_path = self._app_config.models_path
model_path = model_base_path / config.path
@ -72,5 +76,3 @@ class LoraLoader(ModelLoader):
result = model_path.resolve(), config, submodel_type
return result

View File

@ -0,0 +1,55 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for TI model loading in InvokeAI."""
from pathlib import Path
from typing import Optional, Tuple
from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
class TextualInversionLoader(ModelLoader):
"""Class to load TI models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a TI model.")
model = TextualInversionModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
)
return model
# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_path = self._app_config.models_path / config.path
if config.format == ModelFormat.EmbeddingFolder:
path = model_path / "learned_embeds.bin"
else:
path = model_path
if not path.exists():
raise OSError(f"The embedding file at {path} was not found")
return path, config, submodel_type

View File

@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader

View File

@ -3,13 +3,13 @@
import json
from pathlib import Path
from typing import Optional, Union
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
def calc_model_size_by_data(model: AnyModel) -> int: