# Copyright (c) 2024 The InvokeAI Development team """LoRA model support.""" import torch from safetensors.torch import load_file from pathlib import Path from typing import Dict, Optional, Union, List, Tuple from typing_extensions import Self from invokeai.backend.model_manager import BaseModelType class LoRALayerBase: # rank: Optional[int] # alpha: Optional[float] # bias: Optional[torch.Tensor] # layer_key: str # @property # def scale(self): # return self.alpha / self.rank if (self.alpha and self.rank) else 1.0 def __init__( self, layer_key: str, values: Dict[str, torch.Tensor], ): if "alpha" in values: self.alpha = values["alpha"].item() else: self.alpha = None if "bias_indices" in values and "bias_values" in values and "bias_size" in values: self.bias: Optional[torch.Tensor] = torch.sparse_coo_tensor( values["bias_indices"], values["bias_values"], tuple(values["bias_size"]), ) else: self.bias = None self.rank = None # set in layer implementation self.layer_key = layer_key def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError() def calc_size(self) -> int: model_size = 0 for val in [self.bias]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: if self.bias is not None: self.bias = self.bias.to(device=device, dtype=dtype) # TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): # up: torch.Tensor # mid: Optional[torch.Tensor] # down: torch.Tensor def __init__( self, layer_key: str, values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] if "lora_mid.weight" in values: self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] else: self.mid = None self.rank = self.down.shape[0] def get_weight(self, orig_weight: torch.Tensor): if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) weight = torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up, down) else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.up, self.mid, self.down]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: super().to(device=device, dtype=dtype) self.up = self.up.to(device=device, dtype=dtype) self.down = self.down.to(device=device, dtype=dtype) if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) class LoHALayer(LoRALayerBase): # w1_a: torch.Tensor # w1_b: torch.Tensor # w2_a: torch.Tensor # w2_b: torch.Tensor # t1: Optional[torch.Tensor] = None # t2: Optional[torch.Tensor] = None def __init__( self, layer_key: str, values: Dict[str, torch.Tensor] ): super().__init__(layer_key, values) self.w1_a = values["hada_w1_a"] self.w1_b = values["hada_w1_b"] self.w2_a = values["hada_w2_a"] self.w2_b = values["hada_w2_b"] if "hada_t1" in values: self.t1: Optional[torch.Tensor] = values["hada_t1"] else: self.t1 = None if "hada_t2" in values: self.t2: Optional[torch.Tensor] = values["hada_t2"] else: self.t2 = None self.rank = self.w1_b.shape[0] def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: if self.t1 is None: weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) else: rebuild1 = torch.einsum("i j k l, j r, i p -> p r k l", self.t1, self.w1_b, self.w1_a) rebuild2 = torch.einsum("i j k l, j r, i p -> p r k l", self.t2, self.w2_b, self.w2_a) weight = rebuild1 * rebuild2 return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: super().to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.t1 is not None: self.t1 = self.t1.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) class LoKRLayer(LoRALayerBase): # w1: Optional[torch.Tensor] = None # w1_a: Optional[torch.Tensor] = None # w1_b: Optional[torch.Tensor] = None # w2: Optional[torch.Tensor] = None # w2_a: Optional[torch.Tensor] = None # w2_b: Optional[torch.Tensor] = None # t2: Optional[torch.Tensor] = None def __init__( self, layer_key: str, values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) if "lokr_w1" in values: self.w1: Optional[torch.Tensor] = values["lokr_w1"] self.w1_a = None self.w1_b = None else: self.w1 = None self.w1_a = values["lokr_w1_a"] self.w1_b = values["lokr_w1_b"] if "lokr_w2" in values: self.w2: Optional[torch.Tensor] = values["lokr_w2"] self.w2_a = None self.w2_b = None else: self.w2 = None self.w2_a = values["lokr_w2_a"] self.w2_b = values["lokr_w2_b"] if "lokr_t2" in values: self.t2: Optional[torch.Tensor] = values["lokr_t2"] else: self.t2 = None if "lokr_w1_b" in values: self.rank = values["lokr_w1_b"].shape[0] elif "lokr_w2_b" in values: self.rank = values["lokr_w2_b"].shape[0] else: self.rank = None # unscaled def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: w1: Optional[torch.Tensor] = self.w1 if w1 is None: assert self.w1_a is not None assert self.w1_b is not None w1 = self.w1_a @ self.w1_b w2 = self.w2 if w2 is None: if self.t2 is None: assert self.w2_a is not None assert self.w2_b is not None w2 = self.w2_a @ self.w2_b else: w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) w2 = w2.contiguous() assert w1 is not None assert w2 is not None weight = torch.kron(w1, w2) return weight def calc_size(self) -> int: model_size = super().calc_size() for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]: if val is not None: model_size += val.nelement() * val.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: super().to(device=device, dtype=dtype) if self.w1 is not None: self.w1 = self.w1.to(device=device, dtype=dtype) else: assert self.w1_a is not None assert self.w1_b is not None self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype) if self.w2 is not None: self.w2 = self.w2.to(device=device, dtype=dtype) else: assert self.w2_a is not None assert self.w2_b is not None self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype) if self.t2 is not None: self.t2 = self.t2.to(device=device, dtype=dtype) class FullLayer(LoRALayerBase): # weight: torch.Tensor def __init__( self, layer_key: str, values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) self.weight = values["diff"] if len(values.keys()) > 1: _keys = list(values.keys()) _keys.remove("diff") raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") self.rank = None # unscaled def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: return self.weight def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) class IA3Layer(LoRALayerBase): # weight: torch.Tensor # on_input: torch.Tensor def __init__( self, layer_key: str, values: Dict[str, torch.Tensor], ): super().__init__(layer_key, values) self.weight = values["weight"] self.on_input = values["on_input"] self.rank = None # unscaled def get_weight(self, orig_weight: torch.Tensor): weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) return orig_weight * weight def calc_size(self) -> int: model_size = super().calc_size() model_size += self.weight.nelement() * self.weight.element_size() model_size += self.on_input.nelement() * self.on_input.element_size() return model_size def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): super().to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype) self.on_input = self.on_input.to(device=device, dtype=dtype) AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): _name: str layers: Dict[str, AnyLoRALayer] def __init__( self, name: str, layers: Dict[str, AnyLoRALayer], ): self._name = name self.layers = layers @property def name(self) -> str: return self._name def to( self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ) -> None: # TODO: try revert if exception? for _key, layer in self.layers.items(): layer.to(device=device, dtype=dtype) def calc_size(self) -> int: model_size = 0 for _, layer in self.layers.items(): model_size += layer.calc_size() return model_size @classmethod def _convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """Convert the keys of an SDXL LoRA state_dict to diffusers format. The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in diffusers format, then this function will have no effect. This function is adapted from: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 Args: state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. Raises: ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. Returns: Dict[str, Tensor]: The diffusers-format state_dict. """ converted_count = 0 # The number of Stability AI keys converted to diffusers format. not_converted_count = 0 # The number of keys that were not converted. # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for # `input_blocks_4_1_proj_in`. stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) stability_unet_keys.sort() new_state_dict = {} for full_key, value in state_dict.items(): if full_key.startswith("lora_unet_"): search_key = full_key.replace("lora_unet_", "") # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. position = bisect.bisect_right(stability_unet_keys, search_key) map_key = stability_unet_keys[position - 1] # Now, check if the map_key *actually* matches the search_key. if search_key.startswith(map_key): new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) new_state_dict[new_key] = value converted_count += 1 else: new_state_dict[full_key] = value not_converted_count += 1 elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. new_state_dict[full_key] = value continue else: raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") if converted_count > 0 and not_converted_count > 0: raise ValueError( f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," f" not_converted={not_converted_count}" ) return new_state_dict @classmethod def from_checkpoint( cls, file_path: Union[str, Path], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, base_model: Optional[BaseModelType] = None, ) -> Self: device = device or torch.device("cpu") dtype = dtype or torch.float32 if isinstance(file_path, str): file_path = Path(file_path) model = cls( name=file_path.stem, # TODO: layers={}, ) if file_path.suffix == ".safetensors": state_dict = load_file(file_path.absolute().as_posix(), device="cpu") else: state_dict = torch.load(file_path, map_location="cpu") state_dict = cls._group_state(state_dict) if base_model == BaseModelType.StableDiffusionXL: state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) for layer_key, values in state_dict.items(): # lora and locon if "lora_down.weight" in values: layer: AnyLoRALayer = LoRALayer(layer_key, values) # loha elif "hada_w1_b" in values: layer = LoHALayer(layer_key, values) # lokr elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) # diff elif "diff" in values: layer = FullLayer(layer_key, values) # ia3 elif "weight" in values and "on_input" in values: layer = IA3Layer(layer_key, values) else: print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() layer.to(device=device, dtype=dtype) model.layers[layer_key] = layer return model @staticmethod def _group_state(state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: state_dict_groupped: Dict[str, Dict[str, torch.Tensor]] = {} for key, value in state_dict.items(): stem, leaf = key.split(".", 1) if stem not in state_dict_groupped: state_dict_groupped[stem] = {} state_dict_groupped[stem][leaf] = value return state_dict_groupped # code from # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 def make_sdxl_unet_conversion_map() -> List[Tuple[str,str]]: """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" unet_conversion_map_layer = [] for i in range(3): # num_blocks is 3 in sdxl # loop over downblocks/upblocks for j in range(2): # loop over resnets/attentions for downblocks hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) if i < 3: # no attention layers in down_blocks.3 hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(3): # loop over resnets/attentions for upblocks hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." sd_up_res_prefix = f"output_blocks.{3*i + j}.0." unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) # if i > 0: commentout for sdxl # no attention layers in up_blocks.0 hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) if i < 3: # no downsample in down_blocks.3 hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) hf_mid_atn_prefix = "mid_block.attentions.0." sd_mid_atn_prefix = "middle_block.1." unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): hf_mid_res_prefix = f"mid_block.resnets.{j}." sd_mid_res_prefix = f"middle_block.{2*j}." unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) unet_conversion_map_resnet = [ # (stable-diffusion, HF Diffusers) ("in_layers.0.", "norm1."), ("in_layers.2.", "conv1."), ("out_layers.0.", "norm2."), ("out_layers.3.", "conv2."), ("emb_layers.1.", "time_emb_proj."), ("skip_connection.", "conv_shortcut."), ] unet_conversion_map = [] for sd, hf in unet_conversion_map_layer: if "resnets" in hf: for sd_res, hf_res in unet_conversion_map_resnet: unet_conversion_map.append((sd + sd_res, hf + hf_res)) else: unet_conversion_map.append((sd, hf)) for j in range(2): hf_time_embed_prefix = f"time_embedding.linear_{j+1}." sd_time_embed_prefix = f"time_embed.{j*2}." unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) for j in range(2): hf_label_embed_prefix = f"add_embedding.linear_{j+1}." sd_label_embed_prefix = f"label_emb.0.{j*2}." unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) unet_conversion_map.append(("out.0.", "conv_norm_out.")) unet_conversion_map.append(("out.2.", "conv_out.")) return unet_conversion_map SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() }