mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Handle loras in modular denoise
This commit is contained in:
@ -49,6 +49,9 @@ class LoRALayerBase:
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
for val in [self.bias]:
|
||||
@ -93,6 +96,9 @@ class LoRALayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.up, self.mid, self.down]:
|
||||
@ -149,6 +155,9 @@ class LoHALayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
|
||||
@ -241,6 +250,9 @@ class LoKRLayer(LoRALayerBase):
|
||||
|
||||
return weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
|
||||
@ -293,6 +305,9 @@ class FullLayer(LoRALayerBase):
|
||||
def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
@ -327,6 +342,9 @@ class IA3Layer(LoRALayerBase):
|
||||
assert orig_weight is not None
|
||||
return orig_weight * weight
|
||||
|
||||
def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
|
||||
return {"weight": self.get_weight(orig_module.weight)}
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = super().calc_size()
|
||||
model_size += self.weight.nelement() * self.weight.element_size()
|
||||
|
Reference in New Issue
Block a user