mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Improve the robustness of the logic for determining the PEFT model type. In particular, so that it doesn't incorrectly detect DoRA models as LoRA models.
This commit is contained in:
parent
132aadca15
commit
4af258615f
@ -3,7 +3,7 @@
|
||||
|
||||
import bisect
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
|
||||
return new_state_dict
|
||||
|
||||
@classmethod
|
||||
def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
|
||||
"""Check if the set of keys matches the required and optional keys."""
|
||||
if len(required_keys - keys) > 0:
|
||||
# missing required keys.
|
||||
return False
|
||||
|
||||
non_required_keys = keys - required_keys
|
||||
for k in non_required_keys:
|
||||
if k not in optional_keys:
|
||||
# unexpected key
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
|
||||
"""Infer the parameter-efficient finetuning model type from the state dict keys."""
|
||||
common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
|
||||
if cls._keys_match(
|
||||
peft_layer_keys,
|
||||
required_keys={"lora_down.weight", "lora_up.weight"},
|
||||
optional_keys=common_optional_keys | {"lora_mid.weight"},
|
||||
):
|
||||
return LoRALayer
|
||||
|
||||
if cls._keys_match(
|
||||
peft_layer_keys,
|
||||
required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
|
||||
optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
|
||||
):
|
||||
return LoHALayer
|
||||
|
||||
if cls._keys_match(
|
||||
peft_layer_keys,
|
||||
required_keys=set(),
|
||||
optional_keys=common_optional_keys
|
||||
| {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
|
||||
):
|
||||
return LoKRLayer
|
||||
|
||||
if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
|
||||
return FullLayer
|
||||
|
||||
if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
|
||||
return IA3Layer
|
||||
|
||||
raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
@ -486,30 +535,14 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
if base_model == BaseModelType.StableDiffusionXL:
|
||||
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
|
||||
# We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
|
||||
# each layer.
|
||||
first_module_key = next(iter(state_dict))
|
||||
peft_layer_keys = set(state_dict[first_module_key].keys())
|
||||
layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
|
||||
|
||||
for layer_key, values in state_dict.items():
|
||||
# lora and locon
|
||||
if "lora_down.weight" in values:
|
||||
layer: AnyLoRALayer = LoRALayer(layer_key, values)
|
||||
|
||||
# loha
|
||||
elif "hada_w1_b" in values:
|
||||
layer = LoHALayer(layer_key, values)
|
||||
|
||||
# lokr
|
||||
elif "lokr_w1_b" in values or "lokr_w1" in values:
|
||||
layer = LoKRLayer(layer_key, values)
|
||||
|
||||
# diff
|
||||
elif "diff" in values:
|
||||
layer = FullLayer(layer_key, values)
|
||||
|
||||
# ia3
|
||||
elif "weight" in values and "on_input" in values:
|
||||
layer = IA3Layer(layer_key, values)
|
||||
|
||||
else:
|
||||
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
|
||||
raise Exception("Unknown lora format!")
|
||||
layer = layer_cls(layer_key, values)
|
||||
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
Loading…
Reference in New Issue
Block a user