Handle bias in full/diff lora layer

This commit is contained in:
Sergey Borisov 2024-07-25 02:02:37 +03:00
parent 31949ed2f2
commit 8a9e2f57a4

View File

@ -260,7 +260,9 @@ class LoKRLayer(LoRALayerBase):
class FullLayer(LoRALayerBase): class FullLayer(LoRALayerBase):
# bias handled in LoRALayerBase(calc_size, to)
# weight: torch.Tensor # weight: torch.Tensor
# bias: Optional[torch.Tensor]
def __init__( def __init__(
self, self,
@ -270,11 +272,7 @@ class FullLayer(LoRALayerBase):
super().__init__(layer_key, values) super().__init__(layer_key, values)
self.weight = values["diff"] self.weight = values["diff"]
self.bias = values.get("diff_b", None)
if len(values.keys()) > 1:
_keys = list(values.keys())
_keys.remove("diff")
raise NotImplementedError(f"Unexpected keys in lora diff layer: {_keys}")
self.rank = None # unscaled self.rank = None # unscaled