mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
from typing import Dict, Optional
|
|
|
|
import torch
|
|
|
|
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
|
|
|
|
|
# TODO: find and debug LoRA/LoCon layers that include a bias term
|
|
class LoRALayer(LoRALayerBase):
    """LoRA layer built from ``lora_up``/``lora_down`` factors plus an optional ``lora_mid`` tensor.

    Without a mid tensor the weight delta is ``up @ down`` with both factors
    flattened to 2-D. With a mid tensor the delta is an einsum contraction of
    ``mid`` with the 2-D ``up`` and ``down`` factors.
    NOTE(review): the mid variant is presumably the LoCon / conv form -- confirm
    against the checkpoints that ship ``lora_mid.weight``.
    """

    up: torch.Tensor
    mid: Optional[torch.Tensor]
    down: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Pull the LoRA factor tensors out of ``values``.

        Args:
            layer_key: Identifier for this layer; forwarded to the base class.
            values: Tensors for this layer. ``"lora_up.weight"`` and
                ``"lora_down.weight"`` are required (a missing key raises
                ``KeyError``). ``"lora_mid.weight"`` is optional.
        """
        super().__init__(layer_key, values)

        self.up = values["lora_up.weight"]
        self.down = values["lora_down.weight"]
        # Absent mid tensor -> None, which selects the plain up @ down path in get_weight().
        self.mid = values.get("lora_mid.weight")

        # The leading dimension of the down factor is recorded as the rank.
        self.rank = self.down.shape[0]

    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
        """Compute the weight delta described by this layer's factors.

        Args:
            orig_weight: Not used by this implementation.

        Returns:
            Without ``mid``: a 2-D tensor of shape
            ``(up.shape[0], prod(down.shape[1:]))``. It is not reshaped back to
            any conv kernel shape here; presumably the caller does that -- verify.
            With ``mid``: a 4-D tensor of shape
            ``(up.shape[0], down.shape[1], mid.shape[2], mid.shape[3])``.
        """
        if self.mid is None:
            up_flat = self.up.reshape(self.up.shape[0], -1)
            down_flat = self.down.reshape(self.down.shape[0], -1)
            return up_flat @ down_flat

        # Collapse up/down to their first two dims. This only succeeds when any
        # trailing dims have size 1 (e.g. 1x1 kernels).
        up_2d = self.up.reshape(self.up.shape[0], self.up.shape[1])
        down_2d = self.down.reshape(self.down.shape[0], self.down.shape[1])
        # Contract mid's first two axes with up (i m) and down (n j), keeping mid's trailing two axes.
        return torch.einsum("m n w h, i m, n j -> i j w h", self.mid, up_2d, down_2d)

    def calc_size(self) -> int:
        """Return the base-class size plus the bytes held by ``up``, ``mid`` (if set) and ``down``."""
        base_size = super().calc_size()
        present = (t for t in (self.up, self.mid, self.down) if t is not None)
        return base_size + sum(t.nelement() * t.element_size() for t in present)

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Move/cast this layer's tensors, rebinding the attributes on ``self``.

        Args:
            device: Target device, passed straight to ``Tensor.to``.
            dtype: Target dtype, passed straight to ``Tensor.to``.
        """
        super().to(device=device, dtype=dtype)

        self.up = self.up.to(device=device, dtype=dtype)
        self.down = self.down.to(device=device, dtype=dtype)

        # mid is optional, so only convert it when it was provided.
        if self.mid is not None:
            self.mid = self.mid.to(device=device, dtype=dtype)
|