Suggested changes

Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
Sergey Borisov
2024-07-27 04:25:15 +03:00
parent faa88f72bf
commit 9e582563eb
5 changed files with 41 additions and 26 deletions

View File

@ -71,6 +71,9 @@ class LoRALayerBase:
self.bias = self.bias.to(device=device, dtype=dtype)
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
"""Log a warning if values contains unhandled keys."""
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handle.
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
unknown_keys = set(values.keys()) - all_known_keys
if unknown_keys:
@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
else:
self.rank = None # unscaled
# Although lokr_t1 is not used in the algorithm, it is still defined in LoKR weights.
self.check_keys(
values,
{
@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t1",
"lokr_t2",
},
)