diff --git a/invokeai/backend/lora.py b/invokeai/backend/lora.py
index 0b7128034a..77c2150f33 100644
--- a/invokeai/backend/lora.py
+++ b/invokeai/backend/lora.py
@@ -3,7 +3,7 @@
 import bisect
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Type, Union

 import torch
 from safetensors.torch import load_file

@@ -457,6 +457,55 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):

         return new_state_dict

+    @classmethod
+    def _keys_match(cls, keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
+        """Check if the set of keys matches the required and optional keys."""
+        if len(required_keys - keys) > 0:
+            # missing required keys.
+            return False
+
+        non_required_keys = keys - required_keys
+        for k in non_required_keys:
+            if k not in optional_keys:
+                # unexpected key
+                return False
+
+        return True
+
+    @classmethod
+    def get_layer_type_from_state_dict_keys(cls, peft_layer_keys: set[str]) -> Type[AnyLoRALayer]:
+        """Infer the parameter-efficient finetuning model type from the state dict keys."""
+        common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"lora_down.weight", "lora_up.weight"},
+            optional_keys=common_optional_keys | {"lora_mid.weight"},
+        ):
+            return LoRALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
+            optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
+        ):
+            return LoHALayer
+
+        if cls._keys_match(
+            peft_layer_keys,
+            required_keys=set(),
+            optional_keys=common_optional_keys
+            | {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
+        ):
+            return LoKRLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"diff"}, optional_keys=common_optional_keys):
+            return FullLayer
+
+        if cls._keys_match(peft_layer_keys, required_keys={"weight", "on_input"}, optional_keys=common_optional_keys):
+            return IA3Layer
+
+        raise ValueError(f"Unsupported PEFT model type with keys: {peft_layer_keys}")
+
     @classmethod
     def from_checkpoint(
         cls,
@@ -486,30 +535,14 @@ class LoRAModelRaw(RawModel):  # (torch.nn.Module):
         if base_model == BaseModelType.StableDiffusionXL:
             state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

+        # We assume that all layers have the same PEFT layer type. This saves time by not having to infer the type for
+        # each layer.
+        first_module_key = next(iter(state_dict))
+        peft_layer_keys = set(state_dict[first_module_key].keys())
+        layer_cls = cls.get_layer_type_from_state_dict_keys(peft_layer_keys)
+
         for layer_key, values in state_dict.items():
-            # lora and locon
-            if "lora_down.weight" in values:
-                layer: AnyLoRALayer = LoRALayer(layer_key, values)
-
-            # loha
-            elif "hada_w1_b" in values:
-                layer = LoHALayer(layer_key, values)
-
-            # lokr
-            elif "lokr_w1_b" in values or "lokr_w1" in values:
-                layer = LoKRLayer(layer_key, values)
-
-            # diff
-            elif "diff" in values:
-                layer = FullLayer(layer_key, values)
-
-            # ia3
-            elif "weight" in values and "on_input" in values:
-                layer = IA3Layer(layer_key, values)
-
-            else:
-                print(f">> Encountered unknown lora layer module in {model.name}: {layer_key} - {list(values.keys())}")
-                raise Exception("Unknown lora format!")
+            layer = layer_cls(layer_key, values)

             # lower memory consumption by removing already parsed layer values
             state_dict[layer_key].clear()
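
For review purposes, here is a minimal, self-contained sketch of the matching rule that `_keys_match` implements: a layer's state-dict keys match a layer type iff every required key is present and every remaining key is in the optional set. The names below (`keys_match`, `common_optional_keys`) are illustrative stand-ins, not the invokeai module itself, and the key sets are copied from the diff above.

```python
# Standalone sketch of the _keys_match rule from the diff above.
# Not the invokeai implementation; an equivalent re-statement for review.

def keys_match(keys: set[str], required_keys: set[str], optional_keys: set[str]) -> bool:
    if required_keys - keys:
        return False  # at least one required key is missing
    # Any key that is neither required nor optional is unexpected.
    return (keys - required_keys) <= optional_keys


common_optional_keys = {"alpha", "bias_indices", "bias_values", "bias_size"}

# A typical LoRA layer (lora_down/lora_up plus alpha) matches the LoRA signature.
assert keys_match(
    {"lora_down.weight", "lora_up.weight", "alpha"},
    required_keys={"lora_down.weight", "lora_up.weight"},
    optional_keys=common_optional_keys | {"lora_mid.weight"},
)

# The same keys fail the LoHA signature: the required hada_* keys are missing.
assert not keys_match(
    {"lora_down.weight", "lora_up.weight", "alpha"},
    required_keys={"hada_w1_b", "hada_w1_a", "hada_w2_b", "hada_w2_a"},
    optional_keys=common_optional_keys | {"hada_t1", "hada_t2"},
)

# An unexpected key also fails, even when all required keys are present.
assert not keys_match(
    {"lora_down.weight", "lora_up.weight", "mystery_key"},
    required_keys={"lora_down.weight", "lora_up.weight"},
    optional_keys=common_optional_keys | {"lora_mid.weight"},
)
```

Note the design trade-off in `from_checkpoint`: the layer class is inferred once, from the first module's keys, and reused for every layer. This replaces the old per-layer `if`/`elif` chain, but it relies on the stated assumption that all layers in a checkpoint share the same PEFT layer type; a mixed-type checkpoint would be constructed with the wrong class rather than rejected.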