From 2f8b928486eb8d4e7c94a7eca122e1ab08fbc0e4 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Tue, 1 Aug 2023 17:02:57 +0300 Subject: [PATCH 1/2] Add support for diff/full lora layers --- invokeai/backend/model_management/lora.py | 46 +++++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 4287072a65..14a78693ea 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -325,6 +325,43 @@ class LoKRLayer(LoRALayerBase): self.t2 = self.t2.to(device=device, dtype=dtype) +class FullLayer(LoRALayerBase): + # weight: torch.Tensor + + def __init__( + self, + layer_key: str, + values: dict, + ): + super().__init__(layer_key, values) + + self.weight = values["diff"] + + if len(values.keys()) > 1: + _keys = list(values.keys()) + _keys.remove("diff") + raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}") + + self.rank = None # unscaled + + def get_weight(self): + return self.weight + + def calc_size(self) -> int: + model_size = super().calc_size() + model_size += self.weight.nelement() * self.weight.element_size() + return model_size + + def to( + self, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().to(device=device, dtype=dtype) + + self.weight = self.weight.to(device=device, dtype=dtype) + + class LoRAModel: # (torch.nn.Module): _name: str layers: Dict[str, LoRALayer] @@ -412,10 +449,13 @@ class LoRAModel: # (torch.nn.Module): elif "lokr_w1_b" in values or "lokr_w1" in values: layer = LoKRLayer(layer_key, values) + elif "diff" in values: + layer = FullLayer(layer_key, values) + else: - # TODO: diff/ia3/... format - print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}") - return + # TODO: ia3/... format + print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}") + raise Exception("Unknown lora format!") # lower memory consumption by removing already parsed layer values state_dict[layer_key].clear() From 7d0cc6ec3f97a4d0c8b61c6af6e3bc4e29c4f1b9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 3 Aug 2023 11:18:22 +1000 Subject: [PATCH 2/2] chore: black --- invokeai/backend/model_management/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index 14a78693ea..9f196d659d 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -359,7 +359,7 @@ class FullLayer(LoRALayerBase): ): super().to(device=device, dtype=dtype) - self.weight = self.weight.to(device=device, dtype=dtype) + self.weight = self.weight.to(device=device, dtype=dtype) class LoRAModel: # (torch.nn.Module):