Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Refactor code a bit

commit 31949ed2f2 (parent 0ccb304b8b)
@@ -46,11 +46,18 @@ class LoRALayerBase:
         self.rank = None  # set in layer implementation
         self.layer_key = layer_key

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError()

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError()
+    def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]:
+        return self.bias
+
+    def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
+        params = {"weight": self.get_weight(orig_module.weight)}
+        bias = self.get_bias(orig_module.bias)
+        if bias is not None:
+            params["bias"] = bias
+        return params

     def calc_size(self) -> int:
         model_size = 0
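For context, a minimal usage sketch of the refactored base-class API. This is hypothetical caller code, not part of this commit; the names apply_layer and scale are assumptions. It shows that get_parameters now returns the weight delta and, when present, a bias in a single dict.

# Hypothetical illustration only; not code from this repository.
import torch

def apply_layer(module: torch.nn.Module, layer: "LoRALayerBase", scale: float = 1.0) -> None:
    params = layer.get_parameters(module)            # {"weight": ..., optionally "bias": ...}
    module.weight.data += params["weight"] * scale   # add the layer's weight delta to the original weight
    if "bias" in params and module.bias is not None:
        module.bias.data += params["bias"] * scale   # bias entry exists only when get_bias() returned one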
@@ -79,14 +86,11 @@ class LoRALayer(LoRALayerBase):

         self.up = values["lora_up.weight"]
         self.down = values["lora_down.weight"]
-        if "lora_mid.weight" in values:
-            self.mid: Optional[torch.Tensor] = values["lora_mid.weight"]
-        else:
-            self.mid = None
+        self.mid = values.get("lora_mid.weight", None)

         self.rank = self.down.shape[0]

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -96,9 +100,6 @@ class LoRALayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.up, self.mid, self.down]:
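As a reminder of what this layer computes, here is a generic LoRA composition sketch with assumed shapes, not this repository's exact implementation. In the simple case without "lora_mid.weight", the weight delta is the product of the up and down projections, which is why rank is read from down.shape[0].

# Generic LoRA weight composition, for illustration only (assumed shapes).
import torch

out_features, in_features, rank = 64, 32, 4
up = torch.randn(out_features, rank)     # "lora_up.weight"
down = torch.randn(rank, in_features)    # "lora_down.weight"

delta_w = up @ down                      # full-size update of shape (out_features, in_features)
assert delta_w.shape == (out_features, in_features)
assert down.shape[0] == rank             # matches `self.rank = self.down.shape[0]`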
@@ -131,20 +132,12 @@ class LoHALayer(LoRALayerBase):
         self.w1_b = values["hada_w1_b"]
         self.w2_a = values["hada_w2_a"]
         self.w2_b = values["hada_w2_b"]
-
-        if "hada_t1" in values:
-            self.t1: Optional[torch.Tensor] = values["hada_t1"]
-        else:
-            self.t1 = None
-
-        if "hada_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["hada_t2"]
-        else:
-            self.t2 = None
+        self.t1 = values.get("hada_t1", None)
+        self.t2 = values.get("hada_t2", None)

         self.rank = self.w1_b.shape[0]

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
             weight: torch.Tensor = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -155,9 +148,6 @@ class LoHALayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
@@ -195,37 +185,26 @@ class LoKRLayer(LoRALayerBase):
     ):
         super().__init__(layer_key, values)

-        if "lokr_w1" in values:
-            self.w1: Optional[torch.Tensor] = values["lokr_w1"]
-            self.w1_a = None
-            self.w1_b = None
-        else:
-            self.w1 = None
+        self.w1 = values.get("lokr_w1", None)
+        if self.w1 is None:
             self.w1_a = values["lokr_w1_a"]
             self.w1_b = values["lokr_w1_b"]

-        if "lokr_w2" in values:
-            self.w2: Optional[torch.Tensor] = values["lokr_w2"]
-            self.w2_a = None
-            self.w2_b = None
-        else:
-            self.w2 = None
+        self.w2 = values.get("lokr_w2", None)
+        if self.w2 is None:
             self.w2_a = values["lokr_w2_a"]
             self.w2_b = values["lokr_w2_b"]

-        if "lokr_t2" in values:
-            self.t2: Optional[torch.Tensor] = values["lokr_t2"]
-        else:
-            self.t2 = None
+        self.t2 = values.get("lokr_t2", None)

-        if "lokr_w1_b" in values:
-            self.rank = values["lokr_w1_b"].shape[0]
-        elif "lokr_w2_b" in values:
-            self.rank = values["lokr_w2_b"].shape[0]
+        if self.w1_b is not None:
+            self.rank = self.w1_b.shape[0]
+        elif self.w2_b is not None:
+            self.rank = self.w2_b.shape[0]
         else:
             self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
             assert self.w1_a is not None
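For orientation, a generic LoKr composition sketch, not this repository's exact get_weight body and with assumed shapes: when the full lokr_w1/lokr_w2 factors are absent they are rebuilt from their _a/_b low-rank parts, and the weight delta is their Kronecker product.

# Generic LoKr composition, for illustration only (assumed shapes).
import torch

w1_a, w1_b = torch.randn(8, 4), torch.randn(4, 8)   # low-rank factors used when "lokr_w1" is absent
w1 = w1_a @ w1_b                                     # rebuilt full w1 factor
w2 = torch.randn(16, 16)                             # full "lokr_w2" factor

delta_w = torch.kron(w1, w2)                         # Kronecker product gives the full-size update
assert delta_w.shape == (w1.shape[0] * w2.shape[0], w1.shape[1] * w2.shape[1])
assert w1_b.shape[0] == 4                            # corresponds to `self.rank = self.w1_b.shape[0]`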
@@ -250,9 +229,6 @@ class LoKRLayer(LoRALayerBase):

         return weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
@@ -302,12 +278,9 @@ class FullLayer(LoRALayerBase):

         self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
@@ -335,16 +308,13 @@ class IA3Layer(LoRALayerBase):

         self.rank = None  # unscaled

-    def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
+    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
         if not self.on_input:
             weight = weight.reshape(-1, 1)
         assert orig_weight is not None
         return orig_weight * weight

-    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
-        return {"weight": self.get_weight(orig_module.weight)}
-
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
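With the base class now owning get_parameters and get_bias, a concrete layer only has to provide get_weight (plus its size accounting). A minimal hypothetical subclass sketch, assuming the LoRALayerBase shown above; ScaleLayer and the "scale.weight" key are illustrative names, not part of this commit.

# Hypothetical subclass sketch, for illustration only.
from typing import Dict
import torch

class ScaleLayer(LoRALayerBase):
    def __init__(self, layer_key: str, values: Dict[str, torch.Tensor]):
        super().__init__(layer_key, values)
        self.weight = values["scale.weight"]
        self.rank = None  # unscaled

    def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
        # Called by the base class's get_parameters(); a "bias" entry is added
        # there only if get_bias() returns one.
        return orig_weight * self.weight

    def calc_size(self) -> int:
        return super().calc_size() + self.weight.nelement() * self.weight.element_size()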
|