mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Suggested changes
Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com>
This commit is contained in:
@ -71,6 +71,9 @@ class LoRALayerBase:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
|
||||
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||
"""Log a warning if values contains unhandled keys."""
|
||||
# {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
|
||||
# `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
|
||||
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
unknown_keys = set(values.keys()) - all_known_keys
|
||||
if unknown_keys:
|
||||
@ -232,7 +235,6 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
self.rank = None # unscaled
|
||||
|
||||
# Although lokr_t1 not used in algo, it still defined in LoKR weights
|
||||
self.check_keys(
|
||||
values,
|
||||
{
|
||||
@ -242,7 +244,6 @@ class LoKRLayer(LoRALayerBase):
|
||||
"lokr_w2",
|
||||
"lokr_w2_a",
|
||||
"lokr_w2_b",
|
||||
"lokr_t1",
|
||||
"lokr_t2",
|
||||
},
|
||||
)
|
||||
|
Reference in New Issue
Block a user