mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for LyCORIS IA3 format
This commit is contained in:
@ -122,41 +122,7 @@ class LoRALayerBase:
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def forward(
|
||||
self,
|
||||
module: torch.nn.Module,
|
||||
input_h: Any, # for real looks like Tuple[torch.nn.Tensor] but not sure
|
||||
multiplier: float,
|
||||
):
|
||||
if type(module) == torch.nn.Conv2d:
|
||||
op = torch.nn.functional.conv2d
|
||||
extra_args = dict(
|
||||
stride=module.stride,
|
||||
padding=module.padding,
|
||||
dilation=module.dilation,
|
||||
groups=module.groups,
|
||||
)
|
||||
|
||||
else:
|
||||
op = torch.nn.functional.linear
|
||||
extra_args = {}
|
||||
|
||||
weight = self.get_weight()
|
||||
|
||||
bias = self.bias if self.bias is not None else 0
|
||||
scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
|
||||
return (
|
||||
op(
|
||||
*input_h,
|
||||
(weight + bias).view(module.weight.shape),
|
||||
None,
|
||||
**extra_args,
|
||||
)
|
||||
* multiplier
|
||||
* scale
|
||||
)
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@ -197,7 +163,7 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
@ -260,7 +226,7 @@ class LoHALayer(LoRALayerBase):
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
if self.t1 is None:
|
||||
weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
@ -342,7 +308,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
w1 = self.w1
|
||||
if w1 is None:
|
||||
w1 = self.w1_a @ self.w1_b
|
||||
@ -410,7 +376,7 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self):
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
return self.weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
@ -427,6 +393,44 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
# weight: torch.Tensor
|
||||
# on_input: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
layer_key: str,
|
||||
values: dict,
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
self.weight = values["weight"]
|
||||
self.on_input = values["on_input"]
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: torch.Tensor):
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
return orig_weight * weight
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
model_size += self.on_input.nelement() * self.on_input.element_size()
|
||||
return model_size
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
# TODO: rename all methods used in model logic with Info postfix and remove here Raw postfix
|
||||
class LoRAModelRaw: # (torch.nn.Module):
|
||||
@ -547,11 +551,15 @@ class LoRAModelRaw: # (torch.nn.Module):
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
# TODO: ia3/... format
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
|
||||
|
Reference in New Issue
Block a user