mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
from typing import Dict, Optional
|
|
|
|
import torch
|
|
|
|
from invokeai.backend.lora.lora_layer_base import LoRALayerBase
|
|
|
|
|
|
class FullLayer(LoRALayerBase):
    """Patch layer backed by a single full-size weight tensor.

    The tensor is taken from the ``"diff"`` entry of the ``values`` mapping and
    is returned unchanged by :meth:`get_weight`. Any key other than ``"diff"``
    is rejected at construction time.
    """

    weight: torch.Tensor

    def __init__(
        self,
        layer_key: str,
        values: Dict[str, torch.Tensor],
    ):
        """Build the layer from the tensors stored for ``layer_key``.

        Args:
            layer_key: Identifier forwarded unchanged to ``LoRALayerBase.__init__``.
            values: Tensors for this layer; expected to contain ``"diff"`` and
                no other key.

        Raises:
            KeyError: If ``values`` has no ``"diff"`` entry.
            NotImplementedError: If ``values`` holds any key besides ``"diff"``.
        """
        super().__init__(layer_key, values)

        self.weight = values["diff"]

        if len(values) > 1:
            # List every non-"diff" key, preserving the mapping's own key order.
            unexpected = [key for key in values.keys() if key != "diff"]
            raise NotImplementedError(f"Unexpected keys in lora diff layer: {unexpected}")

        # A full tensor has no rank. The original code tags this "unscaled";
        # presumably LoRALayerBase skips alpha/rank scaling when rank is None
        # (NOTE(review): confirm against LoRALayerBase).
        self.rank = None

    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
        """Return the stored ``"diff"`` tensor; ``orig_weight`` is ignored."""
        return self.weight

    def calc_size(self) -> int:
        """Return the base-class size plus the byte count of ``self.weight``."""
        base_size = super().calc_size()
        weight_bytes = self.weight.element_size() * self.weight.nelement()
        return base_size + weight_bytes

    def to(
        self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Delegate to the base class, then rebind ``self.weight`` to its converted copy.

        Args:
            device: Target device, or ``None`` to keep the current one.
            dtype: Target dtype, or ``None`` to keep the current one.
        """
        super().to(device=device, dtype=dtype)
        self.weight = self.weight.to(device=device, dtype=dtype)