Improve the robustness of the logic for determining the PEFT model type. In particular, so that it doesn't incorrectly detect DoRA models as LoRA models.

This commit is contained in:
Ryan Dick 2024-04-04 11:10:09 -04:00
parent 132aadca15
commit 4af258615f

View File

@ -3,7 +3,7 @@
import bisect
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
from safetensors.torch import load_file
@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
return new_state_dict
@classmethod
def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
"""Check if the set of keys matches the required and optional keys."""
if len(required_keys - keys) > 0:
# missing required keys.
return False
non_required_keys = keys - required_keys
for k in non_required_keys:
if k not in optional_keys:
# unexpected key
return False
return True
@classmethod
def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
"""Infer the parameter-efficient finetuning model type from the state dict keys."""
common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
if cls._keys_match(
peft_layer_keys,
required_keys={"lora_down.weight", "lora_up.weight"},
optional_keys=common_optional_keys | {"lora_mid.weight"},
):
return LoRALayer
if cls._keys_match(
peft_layer_keys,
required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
):
return LoHALayer
if cls._keys_match(
peft_layer_keys,
required_keys=set(),
optional_keys=common_optional_keys
| {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
):
return LoKRLayer
if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
return FullLayer
if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
return IA3Layer
raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
@classmethod
def from_checkpoint(
cls,
@ -486,30 +535,14 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
if base_model == BaseModelType.StableDiffusionXL:
state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
# We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
# each layer.
first_module_key = next(iter(state_dict))
peft_layer_keys = set(state_dict[first_module_key].keys())
layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
for layer_key, values in state_dict.items():
# lora and locon
if "lora_down.weight" in values:
layer: AnyLoRALayer = LoRALayer(layer_key, values)
# loha
elif "hada_w1_b" in values:
layer = LoHALayer(layer_key, values)
# lokr
elif "lokr_w1_b" in values or "lokr_w1" in values:
layer = LoKRLayer(layer_key, values)
# diff
elif "diff" in values:
layer = FullLayer(layer_key, values)
# ia3
elif "weight" in values and "on_input" in values:
layer = IA3Layer(layer_key, values)
else:
print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
raise Exception("Unknown lora format!")
layer = layer_cls(layer_key, values)
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()