diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py
index 4287072a65..14a78693ea 100644
--- a/invokeai/backend/model_management/lora.py
+++ b/invokeai/backend/model_management/lora.py
@@ -325,6 +325,49 @@ class LoKRLayer(LoRALayerBase):
         self.t2 = self.t2.to(device=device, dtype=dtype)
 
 
+class FullLayer(LoRALayerBase):
+    """Full-weight patch layer: holds the raw weight difference ("diff")
+    tensor from the state dict. There is no low-rank factorization, so
+    ``rank`` is ``None`` and ``get_weight`` returns the tensor unscaled."""
+
+    # weight: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["diff"]
+
+        # Only the "diff" key is supported; fail loudly on any extra keys
+        # instead of silently dropping them.
+        if len(values.keys()) > 1:
+            _keys = list(values.keys())
+            _keys.remove("diff")
+            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+
+        self.rank = None  # unscaled
+
+    def get_weight(self):
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)
+
+
 class LoRAModel:  # (torch.nn.Module):
     _name: str
     layers: Dict[str, LoRALayer]
@@ -412,10 +455,13 @@ class LoRAModel:  # (torch.nn.Module):
             elif "lokr_w1_b" in values or "lokr_w1" in values:
                 layer = LoKRLayer(layer_key, values)
 
+            elif "diff" in values:
+                layer = FullLayer(layer_key, values)
+
             else:
-                # TODO: diff/ia3/... format
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
-                return
+                # TODO: ia3/... format
+                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
+                raise Exception("Unknown lora format!")
 
             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()