mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for diff/full lora layers
This commit is contained in:
parent
cfc3a20565
commit
2f8b928486
@ -325,6 +325,43 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class FullLayer(LoRALayerBase):
|
||||||
|
# weight: torch.Tensor
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
layer_key: str,
|
||||||
|
values: dict,
|
||||||
|
):
|
||||||
|
super().__init__(layer_key, values)
|
||||||
|
|
||||||
|
self.weight = values["diff"]
|
||||||
|
|
||||||
|
if len(values.keys()) > 1:
|
||||||
|
_keys = list(values.keys())
|
||||||
|
_keys.remove("diff")
|
||||||
|
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
|
||||||
|
|
||||||
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
def get_weight(self):
|
||||||
|
return self.weight
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
model_size = super().calc_size()
|
||||||
|
model_size += self.weight.nelement() * self.weight.element_size()
|
||||||
|
return model_size
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
):
|
||||||
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
class LoRAModel: # (torch.nn.Module):
|
class LoRAModel: # (torch.nn.Module):
|
||||||
_name: str
|
_name: str
|
||||||
layers: Dict[str, LoRALayer]
|
layers: Dict[str, LoRALayer]
|
||||||
@ -412,10 +449,13 @@ class LoRAModel: # (torch.nn.Module):
|
|||||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||||
layer = LoKRLayer(layer_key, values)
|
layer = LoKRLayer(layer_key, values)
|
||||||
|
|
||||||
|
elif "diff" in values:
|
||||||
|
layer = FullLayer(layer_key, values)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# TODO: diff/ia3/... format
|
# TODO: ia3/... format
|
||||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key}")
|
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||||
return
|
raise Exception("Unknown lora format!")
|
||||||
|
|
||||||
# lower memory consumption by removing already parsed layer values
|
# lower memory consumption by removing already parsed layer values
|
||||||
state_dict[layer_key].clear()
|
state_dict[layer_key].clear()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user