mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for norm layer
This commit is contained in:
parent
7da6120b39
commit
68f993998a
@ -378,7 +378,39 @@ class IA3Layer(LoRALayerBase):
|
|||||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
class NormLayer(LoRALayerBase):
|
||||||
|
# bias handled in LoRALayerBase(calc_size, to)
|
||||||
|
# weight: torch.Tensor
|
||||||
|
# bias: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: Dict[str, torch.Tensor],
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.weight = values["w_norm"]
|
||||||
|
self.bias = values.get("b_norm", None)
|
||||||
|
|
||||||
|
self.rank = None # unscaled
|
||||||
|
self.check_keys(values, {"w_norm", "b_norm"})
|
||||||
|
|
||||||
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
model_size += self.weight.nelement() * self.weight.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
|
||||||
|
|
||||||
|
|
||||||
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||||
@ -519,6 +551,10 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
elif "on_input" in values:
|
elif "on_input" in values:
|
||||||
layer = IA3Layer(layer_key, values)
|
layer = IA3Layer(layer_key, values)
|
||||||
|
|
||||||
|
# norms
|
||||||
|
elif "w_norm" in values:
|
||||||
|
layer = NormLayer(layer_key, values)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
raise Exception("Unknown lora format!")
|
raise Exception("Unknown lora format!")
|
||||||
|
Loading…
Reference in New Issue
Block a user