from __future__ import annotations import copy from pathlib import Path from contextlib import contextmanager from typing import Optional, Dict, Tuple, Any import torch from safetensors.torch import load_file from torch.utils.hooks import RemovableHandle from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel from compel.embeddings_provider import BaseTextualInversionManager class LoRALayerBase: #rank: Optional[int] #alpha: Optional[float] #bias: Optional[torch.Tensor] #layer_key: str #@property #def scale(self): # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 def __init__( self, layer_key: str, values: dict, ): if "alpha" in values: self.alpha = values["alpha"].item() else: self.alpha = None if ( "bias_indices" in values and "bias_values" in values and "bias_size" in values ): self.bias = torch.sparse_coo_tensor( values["bias_indices"], values["bias_values"], tuple(values["bias_size"]), ) else: self.bias = None self.rank = None # set in layer implementation self.layer_key = layer_key def forward( self, module: torch.nn.Module, input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure multiplier: float, ): if type(module) == torch.nn.Conv2d: op = torch.nn.functional.conv2d extra_args = dict( stride=module.stride, padding=module.padding, dilation=module.dilation, groups=module.groups, ) else: op = torch.nn.functional.linear extra_args = {} weight = self.get_weight() bias = self.bias if self.bias is not None else 0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 return op( *input_h, (weight + bias).view(module.weight.shape), None, **extra_args, ) * multiplier * scale def get_weight(self): raise NotImplementedError() def calc_size(self) -> int: model_size = 0 for val in [self.bias]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): #up: torch.Tensor #mid: Optional[torch.Tensor] #down: torch.Tensor def __init__( self, layer_key: str, values: dict, ): super().__init__(layer_key, values) self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] if "lora_mid.weight" in values: self.mid = values["lora_mid.weight"] else: self.mid = None self.rank = self.down.shape[0] def get_weight(self): if self.mid is not None: up = self.up.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(up.shape[0], up.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.up, self.mid, self.down]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().to(device=device, dtype=dtype) self.up = self.up.to(device=device, dtype=dtype) self.down = self.down.to(device=device, dtype=dtype) if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) class LoHALayer(LoRALayerBase): #w1_a: torch.Tensor #w1_b: torch.Tensor #w2_a: torch.Tensor #w2_b: torch.Tensor #t1: Optional[torch.Tensor] = None #t2: Optional[torch.Tensor] = None def __init__( self, layer_key: str, values: dict, ): super().__init__(layer_key, values) self.w1_a = values["hada_w1_a"] self.w1_b = values["hada_w1_b"] self.w2_a = values["hada_w2_a"] self.w2_b = values["hada_w2_b"] if "hada_t1" in values: self.t1 = values["hada_t1"] else: self.t1 = None if "hada_t2" in values: self.t2 = values["hada_t2"] else: self.t2 = None self.rank = self.w1_b.shape[0] def get_weight(self): if self.t1 is None: weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) else: rebuild1 = torch.einsum( "i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a ) rebuild2 = torch.einsum( "i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a ) weight = rebuild1 * rebuild2 return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.t1 is not None: self.t1 = self.t1.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) class LoKRLayer(LoRALayerBase): #w1: Optional[torch.Tensor] = None #w1_a: Optional[torch.Tensor] = None #w1_b: Optional[torch.Tensor] = None #w2: Optional[torch.Tensor] = None #w2_a: Optional[torch.Tensor] = None #w2_b: Optional[torch.Tensor] = None #t2: Optional[torch.Tensor] = None def __init__( self, layer_key: str, values: dict, ): super().__init__(layer_key, values) if "lokr_w1" in values: self.w1 = values["lokr_w1"] self.w1_a = None self.w1_b = None else: self.w1 = None self.w1_a = values["lokr_w1_a"] self.w1_b = values["lokr_w1_b"] if "lokr_w2" in values: self.w2 = values["lokr_w2"] self.w2_a = None self.w2_b = None else: self.w2 = None self.w2_a = values["lokr_w2_a"] self.w2_b = values["lokr_w2_b"] if "lokr_t2" in values: self.t2 = values["lokr_t2"] else: self.t2 = None if "lokr_w1_b" in values: self.rank = values["lokr_w1_b"].shape[0] elif "lokr_w2_b" in values: self.rank = values["lokr_w2_b"].shape[0] else: self.rank = None # unscaled def get_weight(self): w1 = self.w1 if w1 is None: w1 = self.w1_a @ self.w1_b w2 = self.w2 if w2 is None: if self.t2 is None: w2 = self.w2_a @ self.w2_b else: w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) w2 = w2.contiguous() weight = torch.kron(w1, w2) return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().to(device=device, dtype=dtype) if self.w1 is not None: self.w1 = self.w1.to(device=device, dtype=dtype) else: self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.w2 is not None: self.w2 = self.w2.to(device=device, dtype=dtype) else: self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) class LoRAModel: #(torch.nn.Module): _name: str layers: Dict[str, LoRALayer] _device: torch.device _dtype: torch.dtype def __init__( self, name: str, layers: Dict[str, LoRALayer], device: torch.device, dtype: torch.dtype, ): self._name = name self._device = device or torch.cpu self._dtype = dtype or torch.float32 self.layers = layers @property def name(self): return self._name @property def device(self): return self._device @property def dtype(self): return self._dtype def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> LoRAModel: # TODO: try revert if exception? for key, layer in self.layers.items(): layer.to(device=device, dtype=dtype) self._device = device self._dtype = dtype def calc_size(self) -> int: model_size = 0 for _, layer in self.layers.items(): model_size += layer.calc_size() return model_size @classmethod def from_checkpoint( cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): device = device or torch.device("cpu") dtype = dtype or torch.float32 if isinstance(file_path, str): file_path = Path(file_path) model = cls( device=device, dtype=dtype, name=file_path.stem, # TODO: layers=dict(), ) if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") else: state_dict = torch.load(file_path, map_location="cpu") state_dict = cls._group_state(state_dict) for layer_key, values in state_dict.items(): # lora and locon if "lora_down.weight" in values: layer = LoRALayer(layer_key, values) # loha elif "hada_w1_b" in values: layer = LoHALayer(layer_key, values) # lokr elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) else: # TODO: diff/ia3/... format print( f">> Encountered unknown lora layer module in {self.name}: {layer_key}" ) return # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() layer.to(device=device, dtype=dtype) model.layers[layer_key] = layer return model @staticmethod def _group_state(state_dict: dict): state_dict_groupped = dict() for key, value in state_dict.items(): stem, leaf = key.split(".", 1) if stem not in state_dict_groupped: state_dict_groupped[stem] = dict() state_dict_groupped[stem][leaf] = value return state_dict_groupped """ loras = [ (lora_model1, 0.7), (lora_model2, 0.4), ] with LoRAHelper.apply_lora_unet(unet, loras): # unet with applied loras # unmodified unet """ # TODO: rename smth like ModelPatcher and add TI method? class ModelPatcher: @staticmethod def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: assert "." not in lora_key if not lora_key.startswith(prefix): raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") module = model module_key = "" key_parts = lora_key[len(prefix):].split('_') submodule_name = key_parts.pop(0) while len(key_parts) > 0: try: module = module.get_submodule(submodule_name) module_key += "." + submodule_name submodule_name = key_parts.pop(0) except: submodule_name += "_" + key_parts.pop(0) module = module.get_submodule(submodule_name) module_key = (module_key + "." + submodule_name).lstrip(".") return (module_key, module) @staticmethod def _lora_forward_hook( applied_loras: List[Tuple[LoraModel, float]], layer_name: str, ): def lora_forward(module, input_h, output): if len(applied_loras) == 0: return output for lora, weight in applied_loras: layer = lora.layers.get(layer_name, None) if layer is None: continue output += layer.forward(module, input_h, weight) return output return lora_forward @classmethod @contextmanager def apply_lora_unet( cls, unet: UNet2DConditionModel, loras: List[Tuple[LoRAModel, float]], ): with cls.apply_lora(unet, loras, "lora_unet_"): yield @classmethod @contextmanager def apply_lora_text_encoder( cls, text_encoder: CLIPTextModel, loras: List[Tuple[LoRAModel, float]], ): with cls.apply_lora(text_encoder, loras, "lora_te_"): yield @classmethod @contextmanager def apply_lora( cls, model: torch.nn.Module, loras: List[Tuple[LoraModel, float]], prefix: str, ): original_weights = dict() try: with torch.no_grad(): for lora, lora_weight in loras: #assert lora.device.type == "cpu" for layer_key, layer in lora.layers.items(): if not layer_key.startswith(prefix): continue module_key, module = cls._resolve_lora_key(model, layer_key, prefix) if module_key not in original_weights: original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) # enable autocast to calc fp16 loras on cpu with torch.autocast(device_type="cpu"): layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_weight = layer.get_weight() * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris layer_weight = layer_weight.reshape(module.weight.shape) module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype) yield # wait for context manager exit finally: with torch.no_grad(): for module_key, weight in original_weights.items(): model.get_submodule(module_key).weight.copy_(weight) @classmethod @contextmanager def apply_ti( cls, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, ti_list: List[Any], ) -> Tuple[CLIPTokenizer, TextualInversionManager]: init_tokens_count = None new_tokens_added = None try: ti_tokenizer = copy.deepcopy(tokenizer) ti_manager = TextualInversionManager(ti_tokenizer) init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings def _get_trigger(ti, index): trigger = ti.name if index > 0: trigger += f"-!pad-{i}" return f"<{trigger}>" # modify tokenizer new_tokens_added = 0 for ti in ti_list: for i in range(ti.embedding.shape[0]): new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) # modify text_encoder text_encoder.resize_token_embeddings(init_tokens_count + new_tokens_added) model_embeddings = text_encoder.get_input_embeddings() for ti in ti_list: ti_tokens = [] for i in range(ti.embedding.shape[0]): embedding = ti.embedding[i] trigger = _get_trigger(ti, i) token_id = ti_tokenizer.convert_tokens_to_ids(trigger) if token_id == ti_tokenizer.unk_token_id: raise RuntimeError(f"Unable to find token id for token '{trigger}'") if model_embeddings.weight.data[token_id].shape != embedding.shape: raise ValueError( f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}." ) model_embeddings.weight.data[token_id] = embedding.to(device=text_encoder.device, dtype=text_encoder.dtype) ti_tokens.append(token_id) if len(ti_tokens) > 1: ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] yield ti_tokenizer, ti_manager finally: if init_tokens_count and new_tokens_added: text_encoder.resize_token_embeddings(init_tokens_count) class TextualInversionModel: name: str embedding: torch.Tensor # [n, 768]|[n, 1280] @classmethod def from_checkpoint( cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): if not isinstance(file_path, Path): file_path = Path(file_path) result = cls() # TODO: result.name = file_path.stem # TODO: if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") else: state_dict = torch.load(file_path, map_location="cpu") # both v1 and v2 format embeddings # difference mostly in metadata if "string_to_param" in state_dict: if len(state_dict["string_to_param"]) > 1: print(f"Warn: Embedding \"{file_path.name}\" contains multiple tokens, which is not supported. The first token will be used.") result.embedding = next(iter(state_dict["string_to_param"].values())) # v3 (easynegative) elif "emb_params" in state_dict: result.embedding = state_dict["emb_params"] # v4(diffusers bin files) else: result.embedding = next(iter(state_dict.values())) if not isinstance(result.embedding, torch.Tensor): raise ValueError(f"Invalid embeddings file: {file_path.name}") return result class TextualInversionManager(BaseTextualInversionManager): pad_tokens: Dict[int, List[int]] tokenizer: CLIPTokenizer def __init__(self, tokenizer: CLIPTokenizer): self.pad_tokens = dict() self.tokenizer = tokenizer def expand_textual_inversion_token_ids_if_necessary( self, token_ids: list[int] ) -> list[int]: if len(self.pad_tokens) == 0: return token_ids if token_ids[0] == self.tokenizer.bos_token_id: raise ValueError("token_ids must not start with bos_token_id") if token_ids[-1] == self.tokenizer.eos_token_id: raise ValueError("token_ids must not end with eos_token_id") new_token_ids = [] for token_id in token_ids: new_token_ids.append(token_id) if token_id in self.pad_tokens: new_token_ids.extend(self.pad_tokens[token_id]) return new_token_ids