Improve the robustness of the logic for determining the PEFT model type. In particular, ensure that DoRA models are not incorrectly detected as LoRA models.

Ryan Dick 2024-04-04 11:10:09 -04:00
parent 132aadca15
commit 4af258615f


@@ -3,7 +3,7 @@
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 from safetensors.torch import load_file
@@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         return new_state_dict

+    @classmethod
+    def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
+        """Check if the set of keys matches the required and optional keys."""
+        if len(required_keys - keys) > 0:
+            # Missing required keys.
+            return False
+
+        non_required_keys = keys - required_keys
+        for k in non_required_keys:
+            if k not in optional_keys:
+                # Unexpected key.
+                return False
+
+        return True
+
+    @classmethod
+    def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
+        """Infer the parameter-efficient finetuning model type from the state dict keys."""
+        common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"lora_down.weight", "lora_up.weight"},
+            optional_keys=common_optional_keys | {"lora_mid.weight"},
+        ):
+            return LoRALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
+            optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
+        ):
+            return LoHALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys=set(),
+            optional_keys=common_optional_keys
+            | {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
+        ):
+            return LoKRLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
+            return FullLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
+            return IA3Layer
+
+        raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
+
     @classmethod
     def from_checkpoint(
         cls,
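To illustrate why the exact key-set match fixes the DoRA misdetection, here is a minimal standalone sketch (the keys_match helper below is a simplified stand-in, and the dora_scale key is an assumption for the example: DoRA checkpoints commonly carry such an extra scale key alongside the standard LoRA keys). The old membership test accepted any layer containing lora_down.weight, so a DoRA layer passed as LoRA; the new match rejects it because of the unexpected key:

    # Standalone sketch, not the InvokeAI API. "dora_scale" is assumed to be the
    # extra key present in a DoRA layer's state dict.
    def keys_match(keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
        # All required keys are present, and every remaining key is an allowed optional key.
        return required_keys <= keys and (keys - required_keys) <= optional_keys

    lora_required = {"lora_down.weight", "lora_up.weight"}
    lora_optional = {"alpha", "bias_indices", "bias_values", "bias_size", "lora_mid.weight"}

    lora_layer = {"lora_down.weight", "lora_up.weight", "alpha"}
    dora_layer = {"lora_down.weight", "lora_up.weight", "alpha", "dora_scale"}

    # Old check ('"lora_down.weight" in values'): both layers match, so DoRA is misdetected.
    assert "lora_down.weight" in lora_layer and "lora_down.weight" in dora_layer

    # New check: the unexpected "dora_scale" key makes the DoRA layer fail the LoRA match.
    assert keys_match(lora_layer, lora_required, lora_optional)
    assert not keys_match(dora_layer, lora_required, lora_optional)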
@@ -486,30 +535,14 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         if base_model == BaseModelType.StableDiffusionXL:
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

+        # We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
+        # each layer.
+        first_module_key = next(iter(state_dict))
+        peft_layer_keys = set(state_dict[first_module_key].keys())
+        layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
+
         for layer_key, values in state_dict.items():
-            # lora and locon
-            if "lora_down.weight" in values:
-                layer: AnyLoRALayer = LoRALayer(layer_key, values)
-            # loha
-            elif "hada_w1_b" in values:
-                layer = LoHALayer(layer_key, values)
-            # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
-                layer = LoKRLayer(layer_key, values)
-            # diff
-            elif "diff" in values:
-                layer = FullLayer(layer_key, values)
-            # ia3
-            elif "weight" in values and "on_input" in values:
-                layer = IA3Layer(layer_key, values)
-            else:
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
-                raise Exception("Unknown lora format!")
+            layer = layer_cls(layer_key, values)

             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
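For context, a usage sketch of the refactored loading path (module names and tensor shapes are fabricated, and layer_type_of is a reduced stand-in for get_layer_type_from_state_dict_keys): the layer type is inferred once from the first module's key set and reused for every module, so an unsupported checkpoint now fails fast with a ValueError instead of silently matching the first superficially similar if/elif branch:

    # Standalone sketch with fabricated module names and shapes.
    import torch

    def layer_type_of(keys: set[str]) -> str:
        # Stand-in for get_layer_type_from_state_dict_keys, reduced to the LoRA case.
        if {"lora_down.weight", "lora_up.weight"} <= keys:
            return "LoRALayer"
        raise ValueError(f"Unsupported PEFT model type with keys: {keys}")

    grouped: dict[str, dict[str, torch.Tensor]] = {
        "lora_unet_down_blocks_0": {"lora_down.weight": torch.zeros(4, 320),
                                    "lora_up.weight": torch.zeros(320, 4)},
        "lora_unet_down_blocks_1": {"lora_down.weight": torch.zeros(4, 320),
                                    "lora_up.weight": torch.zeros(320, 4)},
    }

    # Infer the layer type once from the first module; all modules are assumed to share it.
    first_module_key = next(iter(grouped))
    layer_type = layer_type_of(set(grouped[first_module_key].keys()))

    for layer_key, values in grouped.items():
        # ... construct the layer from (layer_key, values) here ...
        values.clear()  # free already-parsed tensors to lower peak memory, as in the diff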