Handle loras in modular denoise

Author: Sergey Borisov
Date:   2024-07-24 05:07:29 +03:00
Parent: 7c975f0d00
Commit: ab0bfa709a

4 changed files with 227 additions and 4 deletions


@@ -49,6 +49,9 @@ class LoRALayerBase:
     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         raise NotImplementedError()
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError()
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
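
The hunk above adds get_parameters() as the abstract entry point each LoRA variant implements. A minimal sketch of how a patcher might consume it during denoising; the helper name and the "scaled additive delta" assumption are illustrative, not this repo's actual patcher (which may handle replacement-style layers differently):

import torch

@torch.no_grad()
def patch_module(module: torch.nn.Module, layer: "LoRALayerBase", scale: float) -> None:
    # Ask the layer for tensors keyed by the module's parameter names.
    for name, value in layer.get_parameters(module).items():
        param = module.get_parameter(name)
        # Assumed here: each returned tensor is a delta accumulated in-place.
        param += value.to(dtype=param.dtype, device=param.device) * scale
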
@@ -93,6 +96,9 @@ class LoRALayer(LoRALayerBase):
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.up, self.mid, self.down]:
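
For reference, LoRALayer composes its weight from the up/mid/down factors that calc_size iterates over. A minimal sketch of the standard low-rank composition, ignoring the optional mid (CP-decomposition) tensor; the alpha/rank scaling is an assumption about the common LoRA convention, not taken from this diff:

import torch

def lora_delta(up: torch.Tensor, down: torch.Tensor, alpha: float) -> torch.Tensor:
    # Classic LoRA: a rank-r additive delta, (out, r) @ (r, in), scaled by alpha/r.
    rank = down.shape[0]
    return (up @ down) * (alpha / rank)
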
@@ -149,6 +155,9 @@ class LoHALayer(LoRALayerBase):
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
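
LoHA (low-rank Hadamard) builds its delta as an elementwise product of two rank-factored matrices, which is what the w1_*/w2_* tensors above encode. A hedged sketch of that composition, omitting the optional t1/t2 Tucker tensors:

import torch

def loha_delta(w1_a, w1_b, w2_a, w2_b, scale: float = 1.0) -> torch.Tensor:
    # Hadamard (elementwise) product of two low-rank matrices of equal shape.
    return (w1_a @ w1_b) * (w2_a @ w2_b) * scale
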
@@ -241,6 +250,9 @@ class LoKRLayer(LoRALayerBase):
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
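
LoKR's blocks combine through a Kronecker product, with w1 and w2 each stored either whole or as low-rank pairs (hence both w1 and w1_a/w1_b appearing in calc_size). A sketch of the final combination, assuming the factors have already been multiplied out:

import torch

def lokr_delta(w1: torch.Tensor, w2: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Kronecker product expands an (m, n) and a (p, q) factor into an (m*p, n*q) delta.
    return torch.kron(w1, w2) * scale
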
@@ -293,6 +305,9 @@ class FullLayer(LoRALayerBase):
     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         return self.weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
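
calc_size counts raw tensor storage as element count times bytes per element. A quick self-contained illustration of that arithmetic:

import torch

t = torch.zeros(320, 768, dtype=torch.float16)
print(t.nelement() * t.element_size())  # 245760 elements * 2 bytes = 491520
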
@@ -327,6 +342,9 @@ class IA3Layer(LoRALayerBase):
         assert orig_weight is not None
         return orig_weight * weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
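
Unlike the additive variants, IA3 rescales the original weight per channel, which is the multiply visible in get_weight above. A self-contained sketch; the on_input flag and the reshape-for-broadcasting step are assumptions about the part of the method not shown in this hunk:

import torch

def ia3_weight(orig_weight: torch.Tensor, scale: torch.Tensor, on_input: bool) -> torch.Tensor:
    # Broadcast the learned vector over rows (output side) or columns (input side).
    if not on_input:
        scale = scale.reshape(-1, 1)
    return orig_weight * scale
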