From 653f63ae7196f431006b1717222bab6fb76e311d Mon Sep 17 00:00:00 2001
From: Sergey Borisov
Date: Thu, 25 Jul 2024 02:03:08 +0300
Subject: [PATCH] Add layer keys check

---
 invokeai/backend/lora.py | 47 +++++++++++++++++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index b2cba07b2c..8f5b1f08f5 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -3,7 +3,7 @@
 
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 import torch
 from safetensors.torch import load_file
@@ -70,6 +70,15 @@ class LoRALayerBase:
         if self.bias is not None:
             self.bias = self.bias.to(device=device, dtype=dtype)
 
+    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
+        unknown_keys = set(values.keys()) - all_known_keys
+        if unknown_keys:
+            # TODO: how should this emit a warning log?
+            print(
+                f"[WARN] Unexpected keys found in LoRA/LyCORIS layer; the model might work incorrectly! Keys: {unknown_keys}"
+            )
+
 
 # TODO: find and debug lora/locon with bias
 class LoRALayer(LoRALayerBase):
@@ -89,6 +98,14 @@ class LoRALayer(LoRALayerBase):
         self.mid = values.get("lora_mid.weight", None)
 
         self.rank = self.down.shape[0]
+        self.check_keys(
+            values,
+            {
+                "lora_up.weight",
+                "lora_down.weight",
+                "lora_mid.weight",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.mid is not None:
@@ -136,6 +153,17 @@ class LoHALayer(LoRALayerBase):
         self.t2 = values.get("hada_t2", None)
 
         self.rank = self.w1_b.shape[0]
+        self.check_keys(
+            values,
+            {
+                "hada_w1_a",
+                "hada_w1_b",
+                "hada_w2_a",
+                "hada_w2_b",
+                "hada_t1",
+                "hada_t2",
+            },
+        )
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         if self.t1 is None:
@@ -204,6 +232,21 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled
 
+        # Although lokr_t1 is not used in the algorithm, it is still defined in LoKR weights
+        self.check_keys(
+            values,
+            {
+                "lokr_w1",
+                "lokr_w1_a",
+                "lokr_w1_b",
+                "lokr_w2",
+                "lokr_w2_a",
+                "lokr_w2_b",
+                "lokr_t1",
+                "lokr_t2",
+            },
+        )
+
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         w1: Optional[torch.Tensor] = self.w1
         if w1 is None:
@@ -275,6 +318,7 @@ class FullLayer(LoRALayerBase):
         self.bias = values.get("diff_b", None)
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"diff", "diff_b"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         return self.weight
@@ -305,6 +349,7 @@ class IA3Layer(LoRALayerBase):
         self.on_input = values["on_input"]
 
         self.rank = None  # unscaled
+        self.check_keys(values, {"weight", "on_input"})
 
     def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
         weight = self.weight
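
The check introduced above is a plain set difference: each layer type lists the state-dict keys it consumes, the base class unions in the keys it handles itself (the alpha scale and the sparse-bias triplet), and anything left over is reported. Below is a minimal standalone sketch of the same pattern, runnable outside InvokeAI, with Python's stdlib logging standing in for the TODO'd warning mechanism; the class name, method name, and key names mirror the patch, but everything else is illustrative rather than the repository's actual API.

import logging
from typing import Dict, Set

import torch

logger = logging.getLogger(__name__)


class LoRALayerBase:
    """Sketch of the base class; only the key check is reproduced here."""

    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
        # Keys the base class consumes for every layer type: the alpha scale
        # and the sparse-bias triplet.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(
                "Unexpected keys found in LoRA/LyCORIS layer; the model might "
                "work incorrectly! Keys: %s",
                unknown_keys,
            )


if __name__ == "__main__":
    logging.basicConfig(level=logging.WARNING)
    layer = LoRALayerBase()
    # A LoRA-style state dict polluted with one LoHA key.
    values = {
        "lora_up.weight": torch.zeros(4, 2),
        "lora_down.weight": torch.zeros(2, 4),
        "alpha": torch.tensor(2.0),
        "hada_w1_a": torch.zeros(2, 2),  # unknown to a LoRA layer -> warning
    }
    layer.check_keys(values, {"lora_up.weight", "lora_down.weight", "lora_mid.weight"})

Running the demo emits a single warning naming hada_w1_a, a LoHA key that a plain LoRA layer would otherwise silently ignore; that silent mis-parse is exactly what the patch makes visible.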