Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Improve the robustness of the logic for determining the PEFT model type, in particular so that DoRA models are no longer incorrectly detected as LoRA models.
parent 132aadca15
commit 4af258615f
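For context, the old detection logic (removed in the last hunk below) classified each layer by checking for a single characteristic key, e.g. "lora_down.weight" in values. A DoRA checkpoint typically contains the plain LoRA keys plus an extra magnitude parameter, so that check matched DoRA layers as LoRA. Below is a minimal standalone sketch of the failure mode and of the stricter key-set matching this commit introduces, assuming for illustration a DoRA layer that adds a "dora_scale" entry alongside the usual LoRA keys:

# Minimal sketch of the misdetection this commit fixes. The "dora_scale" key
# is an assumption about DoRA checkpoints, used here only for illustration.
lora_keys = {"lora_down.weight", "lora_up.weight", "alpha"}
dora_keys = {"lora_down.weight", "lora_up.weight", "alpha", "dora_scale"}

def old_is_lora(keys: set[str]) -> bool:
    # Old heuristic: the presence of one characteristic key is enough.
    return "lora_down.weight" in keys

def new_is_lora(keys: set[str]) -> bool:
    # New heuristic: every key must be either required or explicitly optional.
    required = {"lora_down.weight", "lora_up.weight"}
    optional = {"alpha", "bias_indices", "bias_values", "bias_size", "lora_mid.weight"}
    return required <= keys and not (keys - required - optional)

assert old_is_lora(dora_keys)      # old logic: DoRA misdetected as LoRA
assert not new_is_lora(dora_keys)  # new logic: unexpected key, no match
assert new_is_lora(lora_keys)      # plain LoRA still matches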
@@ -3,7 +3,7 @@
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 from safetensors.torch import load_file
@@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         return new_state_dict

+    @classmethod
+    def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
+        """Check if the set of keys matches the required and optional keys."""
+        if len(required_keys - keys) > 0:
+            # missing required keys.
+            return False
+
+        non_required_keys = keys - required_keys
+        for k in non_required_keys:
+            if k not in optional_keys:
+                # unexpected key
+                return False
+
+        return True
+
+    @classmethod
+    def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
+        """Infer the parameter-efficient finetuning model type from the state dict keys."""
+        common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"lora_down.weight", "lora_up.weight"},
+            optional_keys=common_optional_keys | {"lora_mid.weight"},
+        ):
+            return LoRALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
+            optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
+        ):
+            return LoHALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys=set(),
+            optional_keys=common_optional_keys
+            | {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
+        ):
+            return LoKRLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
+            return FullLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
+            return IA3Layer
+
+        raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
+
     @classmethod
     def from_checkpoint(
         cls,
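A usage sketch of the new classmethod follows; the import path is an assumption (the diff does not show the module), and the "dora_scale" key is again illustrative:

# Hedged usage sketch; the module path is assumed, not shown in this diff.
from invokeai.backend.lora import LoRALayer, LoRAModelRaw

# Plain LoRA keys resolve to LoRALayer.
layer_cls = LoRAModelRaw.get_layer_type_from_state_dict_keys(
    {"lora_down.weight", "lora_up.weight", "alpha"}
)
assert layer_cls is LoRALayer

try:
    # A DoRA-style key set now fails fast with ValueError instead of being
    # silently constructed as a LoRALayer.
    LoRAModelRaw.get_layer_type_from_state_dict_keys(
        {"lora_down.weight", "lora_up.weight", "alpha", "dora_scale"}
    )
except ValueError as err:
    print(err)  # Unsupported PEFT model type with keys: ...

Note the design choice in the LoKR branch: it passes required_keys=set() because a LoKR layer may store each factor either as a full tensor (lokr_w1) or as a low-rank pair (lokr_w1_a/lokr_w1_b), so no single key is guaranteed to be present and every key is treated as optional.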
@@ -486,30 +535,14 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         if base_model == BaseModelType.StableDiffusionXL:
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

+        # We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
+        # each layer.
+        first_module_key = next(iter(state_dict))
+        peft_layer_keys = set(state_dict[first_module_key].keys())
+        layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
+
         for layer_key, values in state_dict.items():
-            # lora and locon
-            if "lora_down.weight" in values:
-                layer: AnyLoRALayer = LoRALayer(layer_key, values)
-
-            # loha
-            elif "hada_w1_b" in values:
-                layer = LoHALayer(layer_key, values)
-
-            # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
-                layer = LoKRLayer(layer_key, values)
-
-            # diff
-            elif "diff" in values:
-                layer = FullLayer(layer_key, values)
-
-            # ia3
-            elif "weight" in values and "on_input" in values:
-                layer = IA3Layer(layer_key, values)
-
-            else:
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
-                raise Exception("Unknown lora format!")
+            layer = layer_cls(layer_key, values)

             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
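The retained state_dict[layer_key].clear() line keeps peak memory down: once a layer object owns the tensors, clearing the source entry drops the duplicate references. A standalone sketch of the pattern, with hypothetical module names:

import torch

# Hypothetical checkpoint layout: module name -> {param name: tensor}.
state_dict = {
    "unet_down_0": {"lora_down.weight": torch.randn(8, 64), "lora_up.weight": torch.randn(64, 8)},
    "unet_down_1": {"lora_down.weight": torch.randn(8, 64), "lora_up.weight": torch.randn(64, 8)},
}

layers = {}
for layer_key, values in state_dict.items():
    # Stand-in for layer_cls(layer_key, values): the parsed layer now holds
    # the long-lived references to these tensors.
    layers[layer_key] = dict(values)
    # Clear the source entry so the raw state dict no longer pins tensors
    # that have already been parsed.
    state_dict[layer_key].clear()

assert all(len(v) == 0 for v in state_dict.values())  # raw entries emptied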