mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add layer keys check
This commit is contained in:
parent
8a9e2f57a4
commit
653f63ae71
@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@ -70,6 +70,15 @@ class LoRALayerBase:
|
|||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
|
||||||
|
all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||||
|
unknown_keys = set(values.keys()) - all_known_keys
|
||||||
|
if unknown_keys:
|
||||||
|
# TODO: how to warn log?
|
||||||
|
print(
|
||||||
|
f"[WARN] Unexpected keys found in LoRA/LyCORIS layer, model might work incorrectly! Keys: {unknown_keys}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
# TODO: find and debug lora/locon with bias
|
||||||
class LoRALayer(LoRALayerBase):
|
class LoRALayer(LoRALayerBase):
|
||||||
@ -89,6 +98,14 @@ class LoRALayer(LoRALayerBase):
|
|||||||
self.mid = values.get("lora_mid.weight", None)
|
self.mid = values.get("lora_mid.weight", None)
|
||||||
|
|
||||||
self.rank = self.down.shape[0]
|
self.rank = self.down.shape[0]
|
||||||
|
self.check_keys(
|
||||||
|
values,
|
||||||
|
{
|
||||||
|
"lora_up.weight",
|
||||||
|
"lora_down.weight",
|
||||||
|
"lora_mid.weight",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
@ -136,6 +153,17 @@ class LoHALayer(LoRALayerBase):
|
|||||||
self.t2 = values.get("hada_t2", None)
|
self.t2 = values.get("hada_t2", None)
|
||||||
|
|
||||||
self.rank = self.w1_b.shape[0]
|
self.rank = self.w1_b.shape[0]
|
||||||
|
self.check_keys(
|
||||||
|
values,
|
||||||
|
{
|
||||||
|
"hada_w1_a",
|
||||||
|
"hada_w1_b",
|
||||||
|
"hada_w2_a",
|
||||||
|
"hada_w2_b",
|
||||||
|
"hada_t1",
|
||||||
|
"hada_t2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
if self.t1 is None:
|
if self.t1 is None:
|
||||||
@ -204,6 +232,21 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
else:
|
else:
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
|
||||||
|
# Although lokr_t1 not used in algo, it still defined in LoKR weights
|
||||||
|
self.check_keys(
|
||||||
|
values,
|
||||||
|
{
|
||||||
|
"lokr_w1",
|
||||||
|
"lokr_w1_a",
|
||||||
|
"lokr_w1_b",
|
||||||
|
"lokr_w2",
|
||||||
|
"lokr_w2_a",
|
||||||
|
"lokr_w2_b",
|
||||||
|
"lokr_t1",
|
||||||
|
"lokr_t2",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
w1: Optional[torch.Tensor] = self.w1
|
w1: Optional[torch.Tensor] = self.w1
|
||||||
if w1 is None:
|
if w1 is None:
|
||||||
@ -275,6 +318,7 @@ class FullLayer(LoRALayerBase):
|
|||||||
self.bias = values.get("diff_b", None)
|
self.bias = values.get("diff_b", None)
|
||||||
|
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
self.check_keys(values, {"diff", "diff_b"})
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
return self.weight
|
return self.weight
|
||||||
@ -305,6 +349,7 @@ class IA3Layer(LoRALayerBase):
|
|||||||
self.on_input = values["on_input"]
|
self.on_input = values["on_input"]
|
||||||
|
|
||||||
self.rank = None # unscaled
|
self.rank = None # unscaled
|
||||||
|
self.check_keys(values, {"weight", "on_input"})
|
||||||
|
|
||||||
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor:
|
||||||
weight = self.weight
|
weight = self.weight
|
||||||
|
Loading…
Reference in New Issue
Block a user