Add layer keys check

Sergey Borisov 2024-07-25 02:03:08 +03:00
parent 8a9e2f57a4
commit 653f63ae71


@@ -3,7 +3,7 @@
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from safetensors.torch import load_file
@@ -70,6 +70,15 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            # TODO: how should this warning be logged?
+            print(
+                f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
+            )
+
 
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
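
The TODO in check_keys asks how this warning should be logged. A minimal sketch using Python's standard logging module; the module-level logger and its name are assumptions for illustration (the method is shown unindented for readability), not part of this commit:

import logging
from typing import Dict, Set

import torch

logger = logging.getLogger(__name__)  # assumed module-level logger

def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
    """Warn about state-dict keys that this layer does not consume."""
    all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
    unknown_keys = set(values.keys()) - all_known_keys
    if unknown_keys:
        logger.warning(
            "Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: %s",
            unknown_keys,
        )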
@@ -89,6 +98,14 @@ class LoRALayer(LoRALayerBase):
         self.mid = values.get("lora_mid.weight", None)
 
         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
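
Each subclass passes only its own tensor names; the base keys (alpha and the bias_* entries) are merged in by check_keys itself. A self-contained illustration of the resulting set arithmetic; the values dict and the stray "dora_scale" key are made-up examples:

known = {"lora_up.weight", "lora_down.weight", "lora_mid.weight"}
base = {"alpha", "bias_indices", "bias_values", "bias_size"}

values = {"lora_up.weight": None, "lora_down.weight": None, "alpha": None, "dora_scale": None}
unknown = set(values) - (known | base)
print(unknown)  # {'dora_scale'} -- this is what triggers the warning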
@@ -136,6 +153,17 @@ class LoHALayer(LoRALayerBase):
         self.t2 = values.get("hada_t2", None)
 
         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
@@ -204,6 +232,21 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled
 
+        # Although lokr_t1 is not used in the algorithm, it is still defined in LoKR weights
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t1",
+                "lokr_t2",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
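
The LoKR set lists both storage layouts: each Kronecker factor arrives either whole ("lokr_w1") or as a low-rank pair ("lokr_w1_a" / "lokr_w1_b"), so both spellings must count as known. A hypothetical helper, not part of this commit, sketching how the two forms reduce to one tensor:

from typing import Dict

import torch

def lokr_factor(values: Dict[str, torch.Tensor], name: str) -> torch.Tensor:
    """Return the factor stored under `name`, rebuilding it from its a/b pair when factored."""
    if name in values:
        return values[name]
    return values[f"{name}_a"] @ values[f"{name}_b"]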
@@ -275,6 +318,7 @@ class FullLayer(LoRALayerBase):
         self.bias = values.get("diff_b", None)
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
@@ -305,6 +349,7 @@ class IA3Layer(LoRALayerBase):
         self.on_input = values["on_input"]
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
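
For reference, the per-layer key sets this commit introduces, restated as a plain mapping (inside check_keys each set is unioned with the shared base keys):

KNOWN_LAYER_KEYS = {
    "LoRALayer": {"lora_up.weight", "lora_down.weight", "lora_mid.weight"},
    "LoHALayer": {"hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2"},
    "LoKRLayer": {
        "lokr_w1", "lokr_w1_a", "lokr_w1_b",
        "lokr_w2", "lokr_w2_a", "lokr_w2_b",
        "lokr_t1", "lokr_t2",
    },
    "FullLayer": {"diff", "diff_b"},
    "IA3Layer": {"weight", "on_input"},
}
BASE_KEYS = {"alpha", "bias_indices", "bias_values", "bias_size"}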