diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index e8e2b3f51f..3da27988dc 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -143,7 +143,7 @@ class ModelPatcher: # with torch.autocast(device_type="cpu"): layer.to(dtype=torch.float32) layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 - layer_weight = layer.get_weight() * lora_weight * layer_scale + layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale if module.weight.shape != layer_weight.shape: # TODO: debug on lycoris @@ -361,7 +361,8 @@ class ONNXModelPatcher: layer.to(dtype=torch.float32) layer_key = layer_key.replace(prefix, "") - layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight + # TODO: rewrite to pass original tensor weight(required by ia3) + layer_weight = layer.get_weight(None).detach().cpu().numpy() * lora_weight if layer_key is blended_loras: blended_loras[layer_key] += layer_weight else: diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 1983c05503..c3f25e6852 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -122,41 +122,7 @@ class LoRALayerBase: self.rank = None # set in layer implementation self.layer_key = layer_key - def forward( - self, - module: torch.nn.Module, - input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure - multiplier: float, - ): - if type(module) == torch.nn.Conv2d: - op = torch.nn.functional.conv2d - extra_args = dict( - stride=module.stride, - padding=module.padding, - dilation=module.dilation, - groups=module.groups, - ) - - else: - op = torch.nn.functional.linear - extra_args = {} - - weight = self.get_weight() - - bias = self.bias if self.bias is not None else 0 - scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 - return ( - op( - *input_h, - (weight + bias).view(module.weight.shape), - None, - **extra_args, - ) - * multiplier - * scale - ) - - def get_weight(self): + def get_weight(self, orig_weight: torch.Tensor): raise NotImplementedError() def calc_size(self) -> int: @@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase): self.rank = self.down.shape[0] - def get_weight(self): + def get_weight(self, orig_weight: torch.Tensor): if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) down = self.down.reshape(self.down.shape[0], self.down.shape[1]) @@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase): self.rank = self.w1_b.shape[0] - def get_weight(self): + def get_weight(self, orig_weight: torch.Tensor): if self.t1 is None: weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase): else: self.rank = None # unscaled - def get_weight(self): + def get_weight(self, orig_weight: torch.Tensor): w1 = self.w1 if w1 is None: w1 = self.w1_a @ self.w1_b @@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase): self.rank = None # unscaled - def get_weight(self): + def get_weight(self, orig_weight: torch.Tensor): return self.weight def calc_size(self) -> int: @@ -427,6 +393,44 @@ class FullLayer(LoRALayerBase): self.weight = self.weight.to(device=device, dtype=dtype) +class IA3Layer(LoRALayerBase): + # weight: torch.Tensor + # on_input: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["weight"] + self.on_input = values["on_input"] + + self.rank = None # unscaled + + def get_weight(self, orig_weight: torch.Tensor): + weight = self.weight + if not self.on_input: + weight = weight.reshape(-1, 1) + return orig_weight * weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + model_size += self.on_input.nelement() * self.on_input.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + self.on_input = self.on_input.to(device=device, dtype=dtype) + # TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix class LoRAModelRaw: # (torch.nn.Module): @@ -547,11 +551,15 @@ class LoRAModelRaw: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + # diff elif "diff" in values: layer = FullLayer(layer_key, values) + # ia3 + elif "weight" in values and "on_input" in values: + layer = IA3Layer(layer_key, values) + else: - # TODO: ia3/... format print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") raise Exception("Unknown lora format!") diff --git a/invokeai/backend/model_management/util.py b/invokeai/backend/model_management/util.py index f435ab79b6..0702224bc7 100644 --- a/invokeai/backend/model_management/util.py +++ b/invokeai/backend/model_management/util.py @@ -12,37 +12,43 @@ def lora_token_vector_length(checkpoint: dict) -> int: def _get_shape_1(key, tensor, checkpoint): lora_token_vector_length = None + if "." not in key: + return lora_token_vector_length # wrong key format + model_key, lora_key = key.split(".", 1) + # check lora/locon - if ".lora_down.weight" in key: + if lora_key == "lora_down.weight": lora_token_vector_length = tensor.shape[1] # check loha (don't worry about hada_t1/hada_t2 as it used only in 4d shapes) - elif ".hada_w1_b" in key or ".hada_w2_b" in key: + elif lora_key in ["hada_w1_b", "hada_w2_b"]: lora_token_vector_length = tensor.shape[1] # check lokr (don't worry about lokr_t2 as it used only in 4d shapes) - elif ".lokr_" in key: - _lokr_key = key.split(".")[0] - - if _lokr_key + ".lokr_w1" in checkpoint: - _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1"] - elif _lokr_key + "lokr_w1_b" in checkpoint: - _lokr_w1 = checkpoint[_lokr_key + ".lokr_w1_b"] + elif "lokr_" in lora_key: + if model_key + ".lokr_w1" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1"] + elif model_key + "lokr_w1_b" in checkpoint: + _lokr_w1 = checkpoint[model_key + ".lokr_w1_b"] else: return lora_token_vector_length # unknown format - if _lokr_key + ".lokr_w2" in checkpoint: - _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2"] - elif _lokr_key + "lokr_w2_b" in checkpoint: - _lokr_w2 = checkpoint[_lokr_key + ".lokr_w2_b"] + if model_key + ".lokr_w2" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2"] + elif model_key + "lokr_w2_b" in checkpoint: + _lokr_w2 = checkpoint[model_key + ".lokr_w2_b"] else: return lora_token_vector_length # unknown format lora_token_vector_length = _lokr_w1.shape[1] * _lokr_w2.shape[1] - elif ".diff" in key: + elif lora_key == "diff": lora_token_vector_length = tensor.shape[1] + # ia3 can be detected only by shape[0] in text encoder + elif lora_key == "weight" and "lora_unet_" not in model_key: + lora_token_vector_length = tensor.shape[0] + return lora_token_vector_length lora_token_vector_length = None