Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
WIP - not working
commit 6bfb4927c7 (parent c15e9e23ca)
@@ -1,32 +1,48 @@
-from typing import Dict, Optional
+from typing import Optional
 
 import torch
 
 from invokeai.backend.lora.lora_layer_base import LoRALayerBase
 
 
-# TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
-    # up: torch.Tensor
-    # mid: Optional[torch.Tensor]
-    # down: torch.Tensor
-
     def __init__(
         self,
         layer_key: str,
-        values: Dict[str, torch.Tensor],
+        values: dict[str, torch.Tensor],
     ):
         super().__init__(layer_key, values)
 
         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
 
+        self.mid: Optional[torch.Tensor] = values.get("lora_mid.weight", None)
+        self.dora_scale: Optional[torch.Tensor] = values.get("dora_scale", None)
         self.rank = self.down.shape[0]
 
+    def _apply_dora(self, orig_weight: torch.Tensor, lora_weight: torch.Tensor) -> torch.Tensor:
+        """Apply DoRA to the weight matrix.
+
+        This function is based roughly on the reference implementation in PEFT, but handles scaling in a slightly
+        different way:
+        https://github.com/huggingface/peft/blob/26726bf1ddee6ca75ed4e1bfd292094526707a78/src/peft/tuners/lora/layer.py#L421-L433
+
+        """
+        # Merge the original weight with the LoRA weight.
+        merged_weight = orig_weight + lora_weight
+
+        # Calculate the vector-wise L2 norm of the weight matrix across each column vector.
+        weight_norm: torch.Tensor = torch.linalg.norm(merged_weight, dim=1)
+
+        dora_factor = self.dora_scale / weight_norm
+        new_weight = dora_factor * merged_weight
+
+        # TODO(ryand): This is wasteful. We already have the final weight, but we calculate the diff, because that is
+        # what the `get_weight()` API is expected to return. If we do refactor this, we'll have to give some thought to
+        # how lora weight scaling should be applied - having the full weight diff makes this easy.
+        weight_diff = new_weight - orig_weight
+        return weight_diff
+
     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
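Editor's note: the _apply_dora method added in this hunk renormalizes the merged weight row by row, rescales it by the learned DoRA magnitude, and returns a diff against the original weight. Below is a minimal standalone sketch of that math using torch only; the shapes, tensor names, and the choice of dora_scale are illustrative assumptions for demonstration, not values from this commit or a real checkpoint, and keepdim=True is used here purely so the division broadcasts per row.

import torch

# Illustrative shapes: a small linear weight and a rank-2 LoRA update.
out_features, in_features, rank = 4, 3, 2
orig_weight = torch.randn(out_features, in_features)
lora_weight = torch.randn(out_features, rank) @ torch.randn(rank, in_features)

# Hypothetical DoRA magnitude vector (stored as "dora_scale" in checkpoints);
# shape (out_features, 1) so it broadcasts across each row of the weight.
dora_scale = torch.linalg.norm(orig_weight, dim=1, keepdim=True)

# DoRA: merge, normalize each row of the merged weight, rescale by the magnitude.
merged_weight = orig_weight + lora_weight
weight_norm = torch.linalg.norm(merged_weight, dim=1, keepdim=True)
new_weight = (dora_scale / weight_norm) * merged_weight

# get_weight() is expected to return a diff, so the final weight is recovered
# by adding the diff back onto the original weight.
weight_diff = new_weight - orig_weight
assert torch.allclose(orig_weight + weight_diff, new_weight)

With dora_scale initialized to the row norms of the original weight (as in this sketch), the diff is zero whenever the LoRA update is zero, which matches the usual DoRA initialization of the magnitude vector.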
@@ -35,6 +51,10 @@ class LoRALayer(LoRALayerBase):
         else:
             weight = self.up.reshape(self.up.shape[0], -1) @ self.down.reshape(self.down.shape[0], -1)
 
+        if self.dora_scale is not None:
+            assert orig_weight is not None
+            weight = self._apply_dora(orig_weight, weight)
+
         return weight
 
     def calc_size(self) -> int:
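Editor's note: this hunk wires DoRA into get_weight(): the plain LoRA product up @ down is computed first, and only when a dora_scale tensor was present in the checkpoint is the result renormalized against the original weight. The following self-contained sketch mirrors that branch with hypothetical names (lora_weight_diff is not an InvokeAI function); the renormalization itself is the same computation sketched after the first hunk.

from typing import Optional

import torch


def lora_weight_diff(
    up: torch.Tensor,
    down: torch.Tensor,
    orig_weight: Optional[torch.Tensor] = None,
    dora_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Plain LoRA contribution; this is already a diff relative to the original weight.
    weight = up.reshape(up.shape[0], -1) @ down.reshape(down.shape[0], -1)
    if dora_scale is not None:
        # DoRA needs the original weight to renormalize the merged result.
        assert orig_weight is not None
        merged = orig_weight + weight
        norm = torch.linalg.norm(merged, dim=1, keepdim=True)
        weight = (dora_scale / norm) * merged - orig_weight
    return weight


# Toy usage with illustrative shapes.
orig = torch.randn(4, 3)
up, down = torch.randn(4, 2), torch.randn(2, 3)
dora_scale = torch.linalg.norm(orig, dim=1, keepdim=True)
print(lora_weight_diff(up, down, orig, dora_scale).shape)  # torch.Size([4, 3])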
@@ -56,3 +76,6 @@ class LoRALayer(LoRALayerBase):
 
         if self.mid is not None:
             self.mid = self.mid.to(device=device, dtype=dtype)
+
+        if self.dora_scale is not None:
+            self.dora_scale = self.dora_scale.to(device=device, dtype=dtype)