diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 21b99d7f6c..714a4a8a2d 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -46,11 +46,18 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError()
+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params

     def calc_size(self) -> int:
         model_size = 0
@@ -79,14 +86,11 @@ class LoRALayer(LoRALayerBase):
         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]

-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -96,9 +100,6 @@ class LoRALayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.up, self.mid, self.down]:
@@ -131,20 +132,12 @@ class LoHALayer(LoRALayerBase):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -155,9 +148,6 @@ class LoHALayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
@@ -195,37 +185,26 @@ class LoKRLayer(LoRALayerBase):
     ):
         super().__init__(layer_key, values)

-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]

-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]

-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)

-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
@@ -250,9 +229,6 @@ class LoKRLayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
@@ -302,12 +278,9 @@ class FullLayer(LoRALayerBase):

         self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
@@ -335,16 +308,13 @@ class IA3Layer(LoRALayerBase):

         self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
         assert orig_weight is not None
         return orig_weight * weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()