From 6a8eb392b28045412d97ec1e2496cb04ddb0c17f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 11 Aug 2023 11:35:47 -0400 Subject: [PATCH] Add support for loading SDXL LoRA weights in diffusers format. --- .../backend/model_management/models/lora.py | 92 +++++++++++++------ 1 file changed, 63 insertions(+), 29 deletions(-) diff --git a/invokeai/backend/model_management/models/lora.py b/invokeai/backend/model_management/models/lora.py index 88a50fb4fd..b6f321d60b 100644 --- a/invokeai/backend/model_management/models/lora.py +++ b/invokeai/backend/model_management/models/lora.py @@ -1,18 +1,21 @@ +import bisect import os -import torch from enum import Enum -from typing import Optional, Dict, Union, Literal, Any from pathlib import Path +from typing import Dict, Optional, Union + +import torch from safetensors.torch import load_file + from .base import ( + BaseModelType, + InvalidModelException, ModelBase, ModelConfigBase, - BaseModelType, + ModelNotFoundException, ModelType, SubModelType, classproperty, - InvalidModelException, - ModelNotFoundException, ) @@ -482,30 +485,61 @@ class LoRAModelRaw: # (torch.nn.Module): return model_size @classmethod - def _convert_sdxl_compvis_keys(cls, state_dict): + def _convert_sdxl_keys_to_diffusers_format(cls, state_dict): + """Convert the keys of an SDXL LoRA state_dict to diffusers format. + + The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in + diffusers format, then this function will have no effect. + + This function is adapted from: + https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409 + + Args: + state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict. + + Raises: + ValueError: If state_dict contains an unrecognized key, or not all keys could be converted. + + Returns: + Dict[str, Tensor]: The diffusers-format state_dict. + """ + converted_count = 0 # The number of Stability AI keys converted to diffusers format. + not_converted_count = 0 # The number of keys that were not converted. + + # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes. + # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for + # `input_blocks_4_1_proj_in`. + stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP) + stability_unet_keys.sort() + new_state_dict = dict() for full_key, value in state_dict.items(): - if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): - continue # clip same + if full_key.startswith("lora_unet_"): + search_key = full_key.replace("lora_unet_", "") + # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix. + position = bisect.bisect_right(stability_unet_keys, search_key) + map_key = stability_unet_keys[position - 1] + # Now, check if the map_key *actually* matches the search_key. + if search_key.startswith(map_key): + new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key]) + new_state_dict[new_key] = value + converted_count += 1 + else: + new_state_dict[full_key] = value + not_converted_count += 1 + elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"): + # The CLIP text encoders have the same keys in both Stability AI and diffusers formats. + new_state_dict[full_key] = value + continue + else: + raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.") - if not full_key.startswith("lora_unet_"): - raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}") - src_key = full_key.replace("lora_unet_", "") - try: - dst_key = None - while "_" in src_key: - if src_key in SDXL_UNET_COMPVIS_MAP: - dst_key = SDXL_UNET_COMPVIS_MAP[src_key] - break - src_key = "_".join(src_key.split("_")[:-1]) + if converted_count > 0 and not_converted_count > 0: + raise ValueError( + f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count}," + f" not_converted={not_converted_count}" + ) - if dst_key is None: - raise Exception(f"Unknown sdxl lora key - {full_key}") - new_key = full_key.replace(src_key, dst_key) - except: - print(SDXL_UNET_COMPVIS_MAP) - raise - new_state_dict[new_key] = value return new_state_dict @classmethod @@ -537,7 +571,7 @@ class LoRAModelRaw: # (torch.nn.Module): state_dict = cls._group_state(state_dict) if base_model == BaseModelType.StableDiffusionXL: - state_dict = cls._convert_sdxl_compvis_keys(state_dict) + state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict) for layer_key, values in state_dict.items(): # lora and locon @@ -588,6 +622,7 @@ class LoRAModelRaw: # (torch.nn.Module): # code from # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32 def make_sdxl_unet_conversion_map(): + """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.""" unet_conversion_map_layer = [] for i in range(3): # num_blocks is 3 in sdxl @@ -671,7 +706,6 @@ def make_sdxl_unet_conversion_map(): return unet_conversion_map -SDXL_UNET_COMPVIS_MAP = { - f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_") - for sd, hf in make_sdxl_unet_conversion_map() +SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = { + sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map() }