mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Refactor code a bit
This commit is contained in:
parent
0ccb304b8b
commit
31949ed2f2
@ -46,11 +46,18 @@ class LoRALayerBase:
|
||||
self.rank = None # set in layer implementation
|
||||
self.layer_key = layer_key
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError()
|
||||
def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
return self.bias
|
||||
|
||||
def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
params = {"weight": self.get_weight(orig_module.weight)}
|
||||
bias = self.get_bias(orig_module.bias)
|
||||
if bias is not None:
|
||||
params["bias"] = bias
|
||||
return params
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
@ -79,14 +86,11 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
self.up = values["lora_up.weight"]
|
||||
self.down = values["lora_down.weight"]
|
||||
if "lora_mid.weight" in values:
|
||||
self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
|
||||
else:
|
||||
self.mid = None
|
||||
self.mid = values.get("lora_mid.weight", None)
|
||||
|
||||
self.rank = self.down.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.mid is not None:
|
||||
up = self.up.reshape(self.up.shape[0], self.up.shape[1])
|
||||
down = self.down.reshape(self.down.shape[0], self.down.shape[1])
|
||||
@ -96,9 +100,6 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
@ -131,20 +132,12 @@ class LoHALayer(LoRALayerBase):
|
||||
self.w1_b = values["hada_w1_b"]
|
||||
self.w2_a = values["hada_w2_a"]
|
||||
self.w2_b = values["hada_w2_b"]
|
||||
|
||||
if "hada_t1" in values:
|
||||
self.t1: Optional[torch.Tensor] = values["hada_t1"]
|
||||
else:
|
||||
self.t1 = None
|
||||
|
||||
if "hada_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["hada_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
self.t1 = values.get("hada_t1", None)
|
||||
self.t2 = values.get("hada_t2", None)
|
||||
|
||||
self.rank = self.w1_b.shape[0]
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
if self.t1 is None:
|
||||
weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
|
||||
|
||||
@ -155,9 +148,6 @@ class LoHALayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
@ -195,37 +185,26 @@ class LoKRLayer(LoRALayerBase):
|
||||
):
|
||||
super().__init__(layer_key, values)
|
||||
|
||||
if "lokr_w1" in values:
|
||||
self.w1: Optional[torch.Tensor] = values["lokr_w1"]
|
||||
self.w1_a = None
|
||||
self.w1_b = None
|
||||
else:
|
||||
self.w1 = None
|
||||
self.w1 = values.get("lokr_w1", None)
|
||||
if self.w1 is None:
|
||||
self.w1_a = values["lokr_w1_a"]
|
||||
self.w1_b = values["lokr_w1_b"]
|
||||
|
||||
if "lokr_w2" in values:
|
||||
self.w2: Optional[torch.Tensor] = values["lokr_w2"]
|
||||
self.w2_a = None
|
||||
self.w2_b = None
|
||||
else:
|
||||
self.w2 = None
|
||||
self.w2 = values.get("lokr_w2", None)
|
||||
if self.w2 is None:
|
||||
self.w2_a = values["lokr_w2_a"]
|
||||
self.w2_b = values["lokr_w2_b"]
|
||||
|
||||
if "lokr_t2" in values:
|
||||
self.t2: Optional[torch.Tensor] = values["lokr_t2"]
|
||||
else:
|
||||
self.t2 = None
|
||||
self.t2 = values.get("lokr_t2", None)
|
||||
|
||||
if "lokr_w1_b" in values:
|
||||
self.rank = values["lokr_w1_b"].shape[0]
|
||||
elif "lokr_w2_b" in values:
|
||||
self.rank = values["lokr_w2_b"].shape[0]
|
||||
if self.w1_b is not None:
|
||||
self.rank = self.w1_b.shape[0]
|
||||
elif self.w2_b is not None:
|
||||
self.rank = self.w2_b.shape[0]
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
w1: Optional[torch.Tensor] = self.w1
|
||||
if w1 is None:
|
||||
assert self.w1_a is not None
|
||||
@ -250,9 +229,6 @@ class LoKRLayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
@ -302,12 +278,9 @@ class FullLayer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
@ -335,16 +308,13 @@ class IA3Layer(LoRALayerBase):
|
||||
|
||||
self.rank = None # unscaled
|
||||
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||
weight = self.weight
|
||||
if not self.on_input:
|
||||
weight = weight.reshape(-1, 1)
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
|
Loading…
Reference in New Issue
Block a user