diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 11e1764a0a..098b0fe0d7 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -378,7 +378,39 @@ class IA3Layer(LoRALayerBase):
         self.on_input = self.on_input.to(device=device, dtype=dtype)
 
 
-AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
+class NormLayer(LoRALayerBase):
+    # bias handled in LoRALayerBase (calc_size, to)
+    # weight: torch.Tensor
+    # bias: Optional[torch.Tensor]
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: Dict[str, torch.Tensor],
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["w_norm"]
+        self.bias = values.get("b_norm", None)
+
+        self.rank = None  # unscaled
+        self.check_keys(values, {"w_norm", "b_norm"})
+
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> None:
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
+AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer, NormLayer]
 
 
 class LoRAModelRaw(RawModel):  # (torch.nn.Module):
@@ -519,6 +551,10 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
             elif "on_input" in values:
                 layer = IA3Layer(layer_key, values)
 
+            # norms
+            elif "w_norm" in values:
+                layer = NormLayer(layer_key, values)
+
             else:
                 print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
                 raise Exception("Unknown lora format!")
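
Usage sketch (not part of the patch): a minimal, hypothetical example of exercising the new NormLayer on its own, assuming the module path from the diff header and that LoRALayerBase accepts a values dict containing only the norm keys. The layer key and tensor shapes below are illustrative, not taken from a real checkpoint.

    import torch

    from invokeai.backend.lora import NormLayer  # module path taken from the diff header

    # Illustrative state-dict slice for a single normalization layer.
    values = {
        "w_norm": torch.ones(320),   # per-channel scale
        "b_norm": torch.zeros(320),  # optional per-channel shift
    }

    layer = NormLayer("lora_unet_down_blocks_0_norm1", values)

    weight = layer.get_weight(torch.empty(320))  # returns the stored w_norm unchanged
    size_bytes = layer.calc_size()               # weight bytes plus whatever the base class counts
    layer.to(device=torch.device("cpu"), dtype=torch.float16)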