from typing import Dict, Optional

import torch

from invokeai.backend.lora.lora_layer_base import LoRALayerBase


class LoKRLayer(LoRALayerBase):
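    """LoKR ("LoRA with Kronecker product") weight layer.

    The effective weight is torch.kron(w1, w2), where each factor is stored
    either as a full tensor (lokr_w1 / lokr_w2) or as a low-rank pair
    (lokr_w1_a @ lokr_w1_b / lokr_w2_a @ lokr_w2_b). An optional lokr_t2 core
    tensor is contracted with w2_a / w2_b for convolution weights. Key names
    follow the LyCORIS LoKr format.
    """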
    # w1: Optional[torch.Tensor] = None
    # w1_a: Optional[torch.Tensor] = None
    # w1_b: Optional[torch.Tensor] = None
    # w2: Optional[torch.Tensor] = None
    # w2_a: Optional[torch.Tensor] = None
    # w2_b: Optional[torch.Tensor] = None
    # t2: Optional[torch.Tensor] = None

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
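        # Each Kronecker factor is stored either whole or as a low-rank pair;
        # whichever representation is absent is set to None so that
        # get_weight() can pick the matching reconstruction path.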
        super().__init__(layer_key, values)

        if "lokr_w1" in values:
            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
            self.w1_a = None
            self.w1_b = None
        else:
            self.w1 = None
            self.w1_a = values["lokr_w1_a"]
            self.w1_b = values["lokr_w1_b"]

        if "lokr_w2" in values:
            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
            self.w2_a = None
            self.w2_b = None
        else:
            self.w2 = None
            self.w2_a = values["lokr_w2_a"]
            self.w2_b = values["lokr_w2_b"]

        if "lokr_t2" in values:
            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
        else:
            self.t2 = None

        # The rank is the inner dimension of whichever low-rank factor pair is present.
        if "lokr_w1_b" in values:
            self.rank = values["lokr_w1_b"].shape[0]
        elif "lokr_w2_b" in values:
            self.rank = values["lokr_w2_b"].shape[0]
        else:
            self.rank = None  # unscaled

    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
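        # Rebuild both Kronecker factors (from their low-rank pieces when the
        # full tensors were not provided) and combine them with torch.kron().
        # Note that orig_weight is not used by this layer type.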
        w1: Optional[torch.Tensor] = self.w1
        if w1 is None:
            assert self.w1_a is not None
            assert self.w1_b is not None
            w1 = self.w1_a @ self.w1_b

        w2 = self.w2
        if w2 is None:
            if self.t2 is None:
                assert self.w2_a is not None
                assert self.w2_b is not None
                w2 = self.w2_a @ self.w2_b
            else:
                w2 = torch.einsum("i j k l, i p, j r -> p r k l", self.t2, self.w2_a, self.w2_b)

        # For 4D (convolution) factors, broadcast w1 to 4D so torch.kron aligns the dims.
        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
        assert w1 is not None
        assert w2 is not None
        weight = torch.kron(w1, w2)

        return weight

    def calc_size(self) -> int:
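        # Total size in bytes of this layer's tensors, in addition to whatever
        # the base class accounts for.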
        model_size = super().calc_size()
        for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
            if val is not None:
                model_size += val.nelement() * val.element_size()
        return model_size

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
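        # Move only the tensors that are actually present; factors stored as
        # None are left untouched.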
        super().to(device=device, dtype=dtype)

        if self.w1 is not None:
            self.w1 = self.w1.to(device=device, dtype=dtype)
        else:
            assert self.w1_a is not None
            assert self.w1_b is not None
            self.w1_a = self.w1_a.to(device=device, dtype=dtype)
            self.w1_b = self.w1_b.to(device=device, dtype=dtype)

        if self.w2 is not None:
            self.w2 = self.w2.to(device=device, dtype=dtype)
        else:
            assert self.w2_a is not None
            assert self.w2_b is not None
            self.w2_a = self.w2_a.to(device=device, dtype=dtype)
            self.w2_b = self.w2_b.to(device=device, dtype=dtype)

        if self.t2 is not None:
            self.t2 = self.t2.to(device=device, dtype=dtype)