diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index 2b2294bfce..1c188b300d 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -178,6 +178,11 @@ class ModelInstallService(ModelInstallServiceBase):
         )
 
     def import_model(self, source: ModelSource, config: Optional[Dict[str, Any]] = None) -> ModelInstallJob:  # noqa D102
+        similar_jobs = [x for x in self.list_jobs() if x.source == source and not x.in_terminal_state]
+        if similar_jobs:
+            self._logger.warning(f"There is already an active install job for {source}. Not enqueuing.")
+            return similar_jobs[0]
+
         if isinstance(source, LocalModelSource):
             install_job = self._import_local_model(source, config)
             self._install_queue.put(install_job)  # synchronously install
diff --git a/invokeai/backend/model_manager/lora.py b/invokeai/backend/embeddings/lora.py
similarity index 96%
rename from invokeai/backend/model_manager/lora.py
rename to invokeai/backend/embeddings/lora.py
index 4c48de48ec..9a59a97708 100644
--- a/invokeai/backend/model_manager/lora.py
+++ b/invokeai/backend/embeddings/lora.py
@@ -1,13 +1,17 @@
 # Copyright (c) 2024 The InvokeAI Development team
 """LoRA model support."""
 
+import bisect
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
 import torch
 from safetensors.torch import load_file
-from pathlib import Path
-from typing import Dict, Optional, Union, List, Tuple
 from typing_extensions import Self
+
 from invokeai.backend.model_manager import BaseModelType
 
+
 class LoRALayerBase:
     # rank: Optional[int]
     # alpha: Optional[float]
@@ -41,7 +45,7 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key
 
-    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         raise NotImplementedError()
 
     def calc_size(self) -> int:
@@ -82,7 +86,7 @@ class LoRALayer(LoRALayerBase):
 
         self.rank = self.down.shape[0]
 
-    def get_weight(self, orig_weight: torch.Tensor):
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -121,11 +125,7 @@ class LoHALayer(LoRALayerBase):
     # t1: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None
 
-    def __init__(
-        self,
-        layer_key: str,
-        values: Dict[str, torch.Tensor]
-    ):
+    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
         super().__init__(layer_key, values)
 
         self.w1_a = values["hada_w1_a"]
@@ -145,7 +145,7 @@ class LoHALayer(LoRALayerBase):
 
         self.rank = self.w1_b.shape[0]
 
-    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
 
@@ -227,7 +227,7 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -305,7 +305,7 @@ class FullLayer(LoRALayerBase):
 
         self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         return self.weight
 
     def calc_size(self) -> int:
@@ -330,7 +330,7 @@ class IA3Layer(LoRALayerBase):
     def __init__(
         self,
         layer_key: str,
-        values: Dict[str, torch.Tensor],
+        values: Dict[str, torch.Tensor],
     ):
         super().__init__(layer_key, values)
 
@@ -339,10 +339,11 @@ class IA3Layer(LoRALayerBase):
 
         self.rank = None  # unscaled
 
-    def get_weight(self, orig_weight: torch.Tensor):
+    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
+        assert orig_weight is not None
         return orig_weight * weight
 
     def calc_size(self) -> int:
@@ -361,8 +362,10 @@ class IA3Layer(LoRALayerBase):
         self.weight = self.weight.to(device=device, dtype=dtype)
         self.on_input = self.on_input.to(device=device, dtype=dtype)
 
+
 AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
-
+
+
 # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
@@ -530,7 +533,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
 
 # code from
 # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
-def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]:
+def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
     """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
     unet_conversion_map_layer = []
 
diff --git a/invokeai/backend/embeddings/model_patcher.py b/invokeai/backend/embeddings/model_patcher.py
new file mode 100644
index 0000000000..6d73235197
--- /dev/null
+++ b/invokeai/backend/embeddings/model_patcher.py
@@ -0,0 +1,586 @@
+# Copyright (c) 2024 Ryan Dick, Lincoln D. Stein, and the InvokeAI Development Team
+"""These classes implement model patching with LoRAs and Textual Inversions."""
+from __future__ import annotations
+
+import pickle
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from compel.embeddings_provider import BaseTextualInversionManager
+from diffusers import ModelMixin, OnnxRuntimeModel, UNet2DConditionModel
+from safetensors.torch import load_file
+from transformers import CLIPTextModel, CLIPTokenizer
+from typing_extensions import Self
+
+from invokeai.app.shared.models import FreeUConfig
+from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+
+from .lora import LoRAModelRaw
+
+"""
+loras = [
+    (lora_model1, 0.7),
+    (lora_model2, 0.4),
+]
+with LoRAHelper.apply_lora_unet(unet, loras):
+    # unet with applied loras
+# unmodified unet
+
+"""
+
+
+# TODO: rename smth like ModelPatcher and add TI method?
+class ModelPatcher:
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
+
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: UNet2DConditionModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te1_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_sdxl_lora_text_encoder2(
+        cls,
+        text_encoder: CLIPTextModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te2_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: Union[torch.nn.Module, ModelMixin, UNet2DConditionModel],
+        loras: List[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+    ) -> Generator[None, None, None]:
+        original_weights = {}
+        try:
+            with torch.no_grad():
+                for lora, lora_weight in loras:
+                    # assert lora.device.type == "cpu"
+                    for layer_key, layer in lora.layers.items():
+                        if not layer_key.startswith(prefix):
+                            continue
+
+                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                        # should be improved in the following ways:
+                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                        #    LoRA model is applied.
+                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                        #    weights to have valid keys.
+                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                        # All of the LoRA weight calculations will be done on the same device as the module weight.
+                        # (Performance will be best if this is a CUDA device.)
+                        device = module.weight.device
+                        dtype = module.weight.dtype
+
+                        if module_key not in original_weights:
+                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+
+                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                        # same thing in a single call to '.to(...)'.
+                        layer.to(device=device)
+                        layer.to(dtype=torch.float32)
+                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+                        layer.to(device=torch.device("cpu"))
+
+                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
+                        if module.weight.shape != layer_weight.shape:
+                            # TODO: debug on lycoris
+                            assert hasattr(layer_weight, "reshape")
+                            layer_weight = layer_weight.reshape(module.weight.shape)
+
+                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
+                        module.weight += layer_weight.to(dtype=dtype)
+
+            yield  # wait for context manager exit
+
+        finally:
+            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
+            with torch.no_grad():
+                for module_key, weight in original_weights.items():
+                    model.get_submodule(module_key).weight.copy_(weight)
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ti_list: List[Tuple[str, TextualInversionModel]],
+    ) -> Generator[Tuple[CLIPTokenizer, TextualInversionManager], None, None]:
+        init_tokens_count = None
+        new_tokens_added = None
+
+        # TODO: This is required since Transformers 4.32 see
+        # https://github.com/huggingface/transformers/pull/25088
+        # More information by NVIDIA:
+        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
+        # This value might need to be changed in the future and take the GPUs model into account as there seem
+        # to be ideal values for different GPUS. This value is temporary!
+        # For references to the current discussion please see https://github.com/invoke-ai/InvokeAI/pull/4817
+        pad_to_multiple_of = 8
+
+        try:
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
+            ti_manager = TextualInversionManager(ti_tokenizer)
+            init_tokens_count = text_encoder.resize_token_embeddings(None, pad_to_multiple_of).num_embeddings
+
+            def _get_trigger(ti_name: str, index: int) -> str:
+                trigger = ti_name
+                if index > 0:
+                    trigger += f"-!pad-{i}"
+                return f"<{trigger}>"
+
+            def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionModel) -> torch.Tensor:
+                # for SDXL models, select the embedding that matches the text encoder's dimensions
+                if ti.embedding_2 is not None:
+                    return (
+                        ti.embedding_2
+                        if ti.embedding_2.shape[1] == model_embeddings.weight.data[0].shape[0]
+                        else ti.embedding
+                    )
+                else:
+                    return ti.embedding
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti_name, ti in ti_list:
+                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
+
+                for i in range(ti_embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
+
+            # Modify text_encoder.
+            # resize_token_embeddings(...) constructs a new torch.nn.Embedding internally. Initializing the weights of
+            # this embedding is slow and unnecessary, so we wrap this step in skip_torch_weight_init() to save some
+            # time.
+            with skip_torch_weight_init():
+                text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added, pad_to_multiple_of)
+            model_embeddings = text_encoder.get_input_embeddings()
+
+            for ti_name, ti in ti_list:
+                ti_embedding = _get_ti_embedding(text_encoder.get_input_embeddings(), ti)
+
+                ti_tokens = []
+                for i in range(ti_embedding.shape[0]):
+                    embedding = ti_embedding[i]
+                    trigger = _get_trigger(ti_name, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if model_embeddings.weight.data[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                            f" {embedding.shape[0]}, but the current model has token dimension"
+                            f" {model_embeddings.weight.data[token_id].shape[0]}."
+                        )
+
+                    model_embeddings.weight.data[token_id] = embedding.to(
+                        device=text_encoder.device, dtype=text_encoder.dtype
+                    )
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            if init_tokens_count and new_tokens_added:
+                text_encoder.resize_token_embeddings(init_tokens_count, pad_to_multiple_of)
+
+    @classmethod
+    @contextmanager
+    def apply_clip_skip(
+        cls,
+        text_encoder: CLIPTextModel,
+        clip_skip: int,
+    ) -> Generator[None, None, None]:
+        skipped_layers = []
+        try:
+            for _i in range(clip_skip):
+                skipped_layers.append(text_encoder.text_model.encoder.layers.pop(-1))
+
+            yield
+
+        finally:
+            while len(skipped_layers) > 0:
+                text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
+
+    @classmethod
+    @contextmanager
+    def apply_freeu(
+        cls,
+        unet: UNet2DConditionModel,
+        freeu_config: Optional[FreeUConfig] = None,
+    ) -> Generator[None, None, None]:
+        did_apply_freeu = False
+        try:
+            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
+            if freeu_config is not None:
+                unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
+                did_apply_freeu = True
+
+            yield
+
+        finally:
+            assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
+            if did_apply_freeu:
+                unet.disable_freeu()
+
+
+class TextualInversionModel:
+    embedding: torch.Tensor  # [n, 768]|[n, 1280]
+    embedding_2: Optional[torch.Tensor] = None  # [n, 768]|[n, 1280] - for SDXL models
+
+    @classmethod
+    def from_checkpoint(
+        cls,
+        file_path: Union[str, Path],
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> Self:
+        if not isinstance(file_path, Path):
+            file_path = Path(file_path)
+
+        result = cls()  # TODO:
+
+        if file_path.suffix == ".safetensors":
+            state_dict = load_file(file_path.absolute().as_posix(), device="cpu")
+        else:
+            state_dict = torch.load(file_path, map_location="cpu")
+
+        # both v1 and v2 format embeddings
+        # difference mostly in metadata
+        if "string_to_param" in state_dict:
+            if len(state_dict["string_to_param"]) > 1:
+                print(
+                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first',
+                    " token will be used.",
+                )
+
+            result.embedding = next(iter(state_dict["string_to_param"].values()))
+
+        # v3 (easynegative)
+        elif "emb_params" in state_dict:
+            result.embedding = state_dict["emb_params"]
+
+        # v5(sdxl safetensors file)
+        elif "clip_g" in state_dict and "clip_l" in state_dict:
+            result.embedding = state_dict["clip_g"]
+            result.embedding_2 = state_dict["clip_l"]
+
+        # v4(diffusers bin files)
+        else:
+            result.embedding = next(iter(state_dict.values()))
+
+            if len(result.embedding.shape) == 1:
+                result.embedding = result.embedding.unsqueeze(0)
+
+            if not isinstance(result.embedding, torch.Tensor):
+                raise ValueError(f"Invalid embeddings file: {file_path.name}")
+
+        return result
+
+
+# no type hints for BaseTextualInversionManager?
+class TextualInversionManager(BaseTextualInversionManager):  # type: ignore
+    pad_tokens: Dict[int, List[int]]
+    tokenizer: CLIPTokenizer
+
+    def __init__(self, tokenizer: CLIPTokenizer):
+        self.pad_tokens = {}
+        self.tokenizer = tokenizer
+
+    def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
+        if len(self.pad_tokens) == 0:
+            return token_ids
+
+        if token_ids[0] == self.tokenizer.bos_token_id:
+            raise ValueError("token_ids must not start with bos_token_id")
+        if token_ids[-1] == self.tokenizer.eos_token_id:
+            raise ValueError("token_ids must not end with eos_token_id")
+
+        new_token_ids = []
+        for token_id in token_ids:
+            new_token_ids.append(token_id)
+            if token_id in self.pad_tokens:
+                new_token_ids.extend(self.pad_tokens[token_id])
+
+        # Do not exceed the max model input size
+        # The -2 here is compensating for compensate compel.embeddings_provider.get_token_ids(),
+        # which first removes and then adds back the start and end tokens.
+        max_length = list(self.tokenizer.max_model_input_sizes.values())[0] - 2
+        if len(new_token_ids) > max_length:
+            new_token_ids = new_token_ids[0:max_length]
+
+        return new_token_ids
+
+
+class ONNXModelPatcher:
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+    ) -> Generator[None, None, None]:
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    # based on
+    # https://github.com/ssube/onnx-web/blob/ca2e436f0623e18b4cfe8a0363fcfcf10508acf7/api/onnx_web/convert/diffusion/lora.py#L323
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: IAIOnnxRuntimeModel,
+        loras: List[Tuple[LoRAModelRaw, float]],
+        prefix: str,
+    ) -> Generator[None, None, None]:
+        from .models.base import IAIOnnxRuntimeModel
+
+        if not isinstance(model, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        orig_weights = {}
+
+        try:
+            blended_loras: Dict[str, torch.Tensor] = {}
+
+            for lora, lora_weight in loras:
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    layer.to(dtype=torch.float32)
+                    layer_key = layer_key.replace(prefix, "")
+                    # TODO: rewrite to pass original tensor weight(required by ia3)
+                    layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight
+                    if layer_key in blended_loras:
+                        blended_loras[layer_key] += layer_weight
+                    else:
+                        blended_loras[layer_key] = layer_weight
+
+            node_names = {}
+            for node in model.nodes.values():
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
+            ti_manager = TextualInversionManager(ti_tokenizer)
+
+            def _get_trigger(ti_name: str, index: int) -> str:
+                trigger = ti_name
+                if index > 0:
+                    trigger += f"-!pad-{i}"
+                return f"<{trigger}>"
+
+            # modify text_encoder
+            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti_name, ti in ti_list:
+                if ti.embedding_2 is not None:
+                    ti_embedding = (
+                        ti.embedding_2 if ti.embedding_2.shape[1] == orig_embeddings.shape[0] else ti.embedding
+                    )
+                else:
+                    ti_embedding = ti.embedding
+
+                for i in range(ti_embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti_name, i))
+
+            embeddings = np.concatenate(
+                (np.copy(orig_embeddings), np.zeros((new_tokens_added, orig_embeddings.shape[1]))),
+                axis=0,
+            )
+
+            for ti_name, _ in ti_list:
+                ti_tokens = []
+                for i in range(ti_embedding.shape[0]):
+                    embedding = ti_embedding[i].detach().numpy()
+                    trigger = _get_trigger(ti_name, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if embeddings[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                            f" {embedding.shape[0]}, but the current model has token dimension"
+                            f" {embeddings[token_id].shape[0]}."
+                        )
+
+                    embeddings[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(
+                orig_embeddings.dtype
+            )
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            # restore
+            if orig_embeddings is not None:
+                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index d72f55794d..aed5eb60d5 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -102,7 +102,7 @@ class ModelPatcher:
     def apply_lora(
         cls,
         model: torch.nn.Module,
-        loras: List[Tuple[LoRAModel, float]],
+        loras: List[Tuple[LoRAModel, float]],  # THIS IS INCORRECT. IT IS ACTUALLY A LoRAModelRaw
         prefix: str,
     ):
         original_weights = {}
@@ -194,6 +194,8 @@ class ModelPatcher:
                 return f"<{trigger}>"
 
             def _get_ti_embedding(model_embeddings, ti):
+                print(f"DEBUG: model_embeddings={type(model_embeddings)}, ti={type(ti)}")
+                print(f"DEBUG: is it an nn.Module? {isinstance(model_embeddings, torch.nn.Module)}")
                 # for SDXL models, select the embedding that matches the text encoder's dimensions
                 if ti.embedding_2 is not None:
                     return (
@@ -202,6 +204,7 @@ class ModelPatcher:
                         else ti.embedding
                     )
                 else:
+                    print(f"DEBUG: ti.embedding={type(ti.embedding)}")
                     return ti.embedding
 
             # modify tokenizer
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index e59a84d729..4488f8eafc 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -28,9 +28,11 @@ from diffusers import ModelMixin
 from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
 from typing_extensions import Annotated, Any, Dict
 
-from .onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+
 from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 
+
 class InvalidModelConfigException(Exception):
     """Exception for when config parser doesn't recognized this combination of model type and format."""
 
diff --git a/invokeai/backend/model_manager/load/load_default.py b/invokeai/backend/model_manager/load/load_default.py
index 453283e9b4..adc84d2051 100644
--- a/invokeai/backend/model_manager/load/load_default.py
+++ b/invokeai/backend/model_manager/load/load_default.py
@@ -10,11 +10,17 @@ from diffusers import ModelMixin
 from diffusers.configuration_utils import ConfigMixin
 
 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
+from invokeai.backend.model_manager import (
+    AnyModel,
+    AnyModelConfig,
+    InvalidModelConfigException,
+    ModelRepoVariant,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
-from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
+from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data, calc_model_size_by_fs
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.util.devices import choose_torch_device, torch_dtype
 
@@ -160,4 +166,3 @@ class ModelLoader(ModelLoaderBase):
         submodel_type: Optional[SubModelType] = None,
     ) -> AnyModel:
         raise NotImplementedError
-
diff --git a/invokeai/backend/model_manager/load/memory_snapshot.py b/invokeai/backend/model_manager/load/memory_snapshot.py
index 295be0c551..346f5dc424 100644
--- a/invokeai/backend/model_manager/load/memory_snapshot.py
+++ b/invokeai/backend/model_manager/load/memory_snapshot.py
@@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
     if snapshot_1.vram is not None and snapshot_2.vram is not None:
         msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
 
-    return "\n"+msg if len(msg)>0 else msg
+    return "\n" + msg if len(msg) > 0 else msg
diff --git a/invokeai/backend/model_manager/load/model_cache/__init__.py b/invokeai/backend/model_manager/load/model_cache/__init__.py
index 50cafa3769..6c87e2519e 100644
--- a/invokeai/backend/model_manager/load/model_cache/__init__.py
+++ b/invokeai/backend/model_manager/load/model_cache/__init__.py
@@ -1,5 +1,3 @@
 """Init file for RamCache."""
-from .model_cache_base import ModelCacheBase
-from .model_cache_default import ModelCache
 
 _all__ = ["ModelCacheBase", "ModelCache"]
diff --git a/invokeai/backend/model_manager/load/model_loaders/controlnet.py b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
index 8e6a80ceb2..e61e2b46a6 100644
--- a/invokeai/backend/model_manager/load/model_loaders/controlnet.py
+++ b/invokeai/backend/model_manager/load/model_loaders/controlnet.py
@@ -14,8 +14,10 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
 from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+
 from .generic_diffusers import GenericDiffusersLoader
 
+
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
 class ControlnetLoader(GenericDiffusersLoader):
@@ -37,7 +39,7 @@ class ControlnetLoader(GenericDiffusersLoader):
         if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
             raise Exception(f"Vae conversion not supported for model type: {config.base}")
         else:
-            assert hasattr(config, 'config')
+            assert hasattr(config, "config")
             config_file = config.config
 
         if weights_path.suffix == ".safetensors":
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
index f92a9048c5..03c26f3a0c 100644
--- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.load.load_base import AnyModelLoader
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 
+
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
 class GenericDiffusersLoader(ModelLoader):
diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
index 63dc3790f1..27ced41c1e 100644
--- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
+++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
@@ -1,11 +1,11 @@
 # Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
 """Class for IP Adapter model loading in InvokeAI."""
 
-import torch
-
 from pathlib import Path
 from typing import Optional
 
+import torch
+
 from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
 from invokeai.backend.model_manager import (
     AnyModel,
@@ -18,6 +18,7 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.load.load_base import AnyModelLoader
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 
+
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""
@@ -36,4 +37,3 @@ class IPAdapterInvokeAILoader(ModelLoader):
             dtype=self._torch_dtype,
         )
         return model
-
diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py
index 4d19aadb7d..d8e5f920e2 100644
--- a/invokeai/backend/model_manager/load/model_loaders/lora.py
+++ b/invokeai/backend/model_manager/load/model_loaders/lora.py
@@ -2,13 +2,12 @@
 
 """Class for LoRA model loading in InvokeAI."""
 
+from logging import Logger
 from pathlib import Path
 from typing import Optional, Tuple
-from logging import Logger
 
-from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
-from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.embeddings.lora import LoRAModelRaw
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
@@ -18,9 +17,11 @@ from invokeai.backend.model_manager import (
     ModelType,
     SubModelType,
 )
-from invokeai.backend.model_manager.lora import LoRAModelRaw
+from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.load_base import AnyModelLoader
 from invokeai.backend.model_manager.load.load_default import ModelLoader
+from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
+
 
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
 @AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
@@ -47,6 +48,7 @@ class LoraLoader(ModelLoader):
     ) -> AnyModel:
         if submodel_type is not None:
             raise ValueError("There are no submodels in a LoRA model.")
+        assert self._model_base is not None
         model = LoRAModelRaw.from_checkpoint(
             file_path=model_path,
             dtype=self._torch_dtype,
@@ -56,9 +58,11 @@ class LoraLoader(ModelLoader):
 
     # override
     def _get_model_path(
-        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
+        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
     ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
-        self._model_base = config.base  # cheating a little - setting this variable for later call to _load_model()
+        self._model_base = (
+            config.base
+        )  # cheating a little - we remember this variable for using in the subsequent call to _load_model()
         model_base_path = self._app_config.models_path
         model_path = model_base_path / config.path
@@ -72,5 +76,3 @@ class LoraLoader(ModelLoader):
         result = model_path.resolve(), config, submodel_type
 
         return result
-
-
diff --git a/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
new file mode 100644
index 0000000000..394fddc75d
--- /dev/null
+++ b/invokeai/backend/model_manager/load/model_loaders/textual_inversion.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
+"""Class for TI model loading in InvokeAI."""
+
+
+from pathlib import Path
+from typing import Optional, Tuple
+
+from invokeai.backend.embeddings.model_patcher import TextualInversionModel as TextualInversionModelRaw
+from invokeai.backend.model_manager import (
+    AnyModel,
+    AnyModelConfig,
+    BaseModelType,
+    ModelFormat,
+    ModelRepoVariant,
+    ModelType,
+    SubModelType,
+)
+from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+from invokeai.backend.model_manager.load.load_default import ModelLoader
+
+
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFile)
+@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.TextualInversion, format=ModelFormat.EmbeddingFolder)
+class TextualInversionLoader(ModelLoader):
+    """Class to load TI models."""
+
+    def _load_model(
+        self,
+        model_path: Path,
+        model_variant: Optional[ModelRepoVariant] = None,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if submodel_type is not None:
+            raise ValueError("There are no submodels in a TI model.")
+        model = TextualInversionModelRaw.from_checkpoint(
+            file_path=model_path,
+            dtype=self._torch_dtype,
+        )
+        return model
+
+    # override
+    def _get_model_path(
+        self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
+    ) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
+        model_path = self._app_config.models_path / config.path
+
+        if config.format == ModelFormat.EmbeddingFolder:
+            path = model_path / "learned_embeds.bin"
+        else:
+            path = model_path
+
+        if not path.exists():
+            raise OSError(f"The embedding file at {path} was not found")
+
+        return path, config, submodel_type
diff --git a/invokeai/backend/model_manager/load/model_loaders/vae.py b/invokeai/backend/model_manager/load/model_loaders/vae.py
index 7a35e53459..882ae05577 100644
--- a/invokeai/backend/model_manager/load/model_loaders/vae.py
+++ b/invokeai/backend/model_manager/load/model_loaders/vae.py
@@ -15,6 +15,7 @@ from invokeai.backend.model_manager import (
 )
 from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
 from invokeai.backend.model_manager.load.load_base import AnyModelLoader
+
 from .generic_diffusers import GenericDiffusersLoader
 
 
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index 404c88bbbc..3f2d22595e 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -3,13 +3,13 @@
 
 import json
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional
 
 import torch
 from diffusers import DiffusionPipeline
 
 from invokeai.backend.model_manager.config import AnyModel
-from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 
 
 def calc_model_size_by_data(model: AnyModel) -> int:
diff --git a/invokeai/backend/model_manager/onnx_runtime.py b/invokeai/backend/onnx/onnx_runtime.py
similarity index 100%
rename from invokeai/backend/model_manager/onnx_runtime.py
rename to invokeai/backend/onnx/onnx_runtime.py