Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Add support for diff/full lora layers

parent cfc3a20565
commit 2f8b928486
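For context: a "full" (diff) LoRA layer carries the entire weight delta for a module instead of a low-rank factorization, so applying it reduces to a single tensor addition. A minimal sketch of the contrast, with illustrative shapes and names (down, up, diff, base_weight are not identifiers from this diff):

import torch

# Classic low-rank LoRA: the delta is reconstructed from two small matrices.
down = torch.randn(4, 320)    # (rank, in_features)
up = torch.randn(320, 4)      # (out_features, rank)
lora_delta = up @ down        # full-size delta, limited to rank 4

# Full/diff layer: the delta is stored directly at full size, no factorization.
diff = torch.randn(320, 320)

# Merging into the base weight is the same addition in both cases.
base_weight = torch.randn(320, 320)
merged = base_weight + diff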
@@ -325,6 +325,43 @@ class LoKRLayer(LoRALayerBase):
         self.t2 = self.t2.to(device=device, dtype=dtype)
+
+
+class FullLayer(LoRALayerBase):
+    # weight: torch.Tensor
+
+    def __init__(
+        self,
+        layer_key: str,
+        values: dict,
+    ):
+        super().__init__(layer_key, values)
+
+        self.weight = values["diff"]
+
+        if len(values.keys()) > 1:
+            _keys = list(values.keys())
+            _keys.remove("diff")
+            raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
+
+        self.rank = None  # unscaled
+
+    def get_weight(self):
+        return self.weight
+
+    def calc_size(self) -> int:
+        model_size = super().calc_size()
+        model_size += self.weight.nelement() * self.weight.element_size()
+        return model_size
+
+    def to(
+        self,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        super().to(device=device, dtype=dtype)
+
+        self.weight = self.weight.to(device=device, dtype=dtype)


 class LoRAModel:  # (torch.nn.Module):
     _name: str
     layers: Dict[str, LoRALayer]
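Going by the constructor above, a FullLayer is built from a state-dict slice whose only key is "diff"; any extra key raises NotImplementedError, which implies the LoRALayerBase plumbing tolerates a bare {"diff": ...} slice. A minimal usage sketch, assuming the module from this diff is importable (the layer key string and tensor shape are made up for illustration):

import torch

values = {"diff": torch.randn(320, 320)}
layer = FullLayer("lora_unet_down_blocks_0", values)  # key name is illustrative

layer.to(device=torch.device("cpu"), dtype=torch.float16)
print(layer.get_weight().shape)  # torch.Size([320, 320])
print(layer.calc_size())         # bytes held by the layer's tensors

# A slice with extra keys is rejected:
# FullLayer("k", {"diff": torch.zeros(1), "alpha": torch.tensor(8.0)})  -> NotImplementedError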
@@ -412,10 +449,13 @@ class LoRAModel:  # (torch.nn.Module):
             elif "lokr_w1_b" in values or "lokr_w1" in values:
                 layer = LoKRLayer(layer_key, values)

+            elif "diff" in values:
+                layer = FullLayer(layer_key, values)
+
             else:
-                # TODO: diff/ia3/... format
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
-                return
+                # TODO: ia3/... format
+                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
+                raise Exception("Unknown lora format!")

             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
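The dispatch now recognizes a bare "diff" key as a full layer, and any remaining unknown format becomes a hard failure instead of a silent return. Condensed, the selection logic reads roughly like this (select_layer_cls is a hypothetical helper, and the earlier LoRA/LoHA branches are elided):

def select_layer_cls(values: dict):
    # ... lora_up/lora_down and hada_* branches come first in the real chain ...
    if "lokr_w1_b" in values or "lokr_w1" in values:
        return LoKRLayer
    elif "diff" in values:
        return FullLayer
    else:
        # was: log and bail out silently; now loading fails loudly
        raise Exception("Unknown lora format!")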