From 6bfb4927c74c19883c8df6806f94e605110e4117 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 8 Apr 2024 10:06:11 -0400 Subject: [PATCH] WIP - not working --- invokeai/backend/lora/lora_layer.py | 45 ++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/lora/lora_layer.py b/invokeai/backend/lora/lora_layer.py index ba5ce594c2..9c1db08b61 100644 --- a/invokeai/backend/lora/lora_layer.py +++ b/invokeai/backend/lora/lora_layer.py @@ -1,32 +1,48 @@ -from typing import Dict, Optional +from typing import Optional import torch from invokeai.backend.lora.lora_layer_base import LoRALayerBase -# TODO: find and debug lora/locon with bias class LoRALayer(LoRALayerBase): - # up: torch.Tensor - # mid: Optional[torch.Tensor] - # down: torch.Tensor - def __init__( self, layer_key: str, - values: Dict[str, torch.Tensor], + values: dict[str, torch.Tensor], ): super().__init__(layer_key, values) self.up = values["lora_up.weight"] self.down = values["lora_down.weight"] - if "lora_mid.weight" in values: - self.mid: Optional[torch.Tensor] = values["lora_mid.weight"] - else: - self.mid = None + self.mid: Optional[torch.Tensor] = values.get("lora_mid.weight", None) + self.dora_scale: Optional[torch.Tensor] = values.get("dora_scale", None) self.rank = self.down.shape[0] + def _apply_dora(self, orig_weight: torch.Tensor, lora_weight: torch.Tensor) -> torch.Tensor: + """Apply DoRA to the weight matrix. + + This function is based roughly on the reference implementation in PEFT, but handles scaling in a slightly + different way: + https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/tuners/lora/layer.py#L421-L433 + + """ + # Merge the original weight with the LoRA weight. + merged_weight = orig_weight + lora_weight + + # Calculate the vector-wise L2 norm of the weight matrix across each column vector. + weight_norm: torch.Tensor = torch.linalg.norm(merged_weight, dim=1) + + dora_factor = self.dora_scale / weight_norm + new_weight = dora_factor * merged_weight + + # TODO(ryand): This is wasteful. We already have the final weight, but we calculate the diff, because that is + # what the `get_weight()` API is expected to return. If we do refactor this, we'll have to give some thought to + # how lora weight scaling should be applied - having the full weight diff makes this easy. + weight_diff = new_weight - orig_weight + return weight_diff + def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor: if self.mid is not None: up = self.up.reshape(self.up.shape[0], self.up.shape[1]) @@ -35,6 +51,10 @@ class LoRALayer(LoRALayerBase): else: weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1) + if self.dora_scale is not None: + assert orig_weight is not None + weight = self._apply_dora(orig_weight, weight) + return weight def calc_size(self) -> int: @@ -56,3 +76,6 @@ class LoRALayer(LoRALayerBase): if self.mid is not None: self.mid = self.mid.to(device=device, dtype=dtype) + + if self.dora_scale is not None: + self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)