diff --git a/invokeai/backend/lora/sdxl_state_dict_utils.py b/invokeai/backend/lora/sdxl_state_dict_utils.py
index 2d107b5d5f..643c7d353b 100644
--- a/invokeai/backend/lora/sdxl_state_dict_utils.py
+++ b/invokeai/backend/lora/sdxl_state_dict_utils.py
@@ -1,74 +1,15 @@
 import bisect
-from typing import Dict, List, Tuple
 
 import torch
 
-from invokeai.backend.lora.sdxl_state_dict_utils import SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP
 
+def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
+    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.
 
-def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
-    """Convert the keys of an SDXL LoRA state_dict to diffusers format.
-
-    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
-    diffusers format, then this function will have no effect.
-
-    This function is adapted from:
-    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
-
-    Args:
-        state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
-
-    Raises:
-        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
-
-    Returns:
-        Dict[str, Tensor]: The diffusers-format state_dict.
+    Ported from:
+    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
     """
-    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
-    not_converted_count = 0  # The number of keys that were not converted.
-
-    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
-    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
-    # `input_blocks_4_1_proj_in`.
-    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
-    stability_unet_keys.sort()
-
-    new_state_dict = {}
-    for full_key, value in state_dict.items():
-        if full_key.startswith("lora_unet_"):
-            search_key = full_key.replace("lora_unet_", "")
-            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
-            position = bisect.bisect_right(stability_unet_keys, search_key)
-            map_key = stability_unet_keys[position - 1]
-            # Now, check if the map_key *actually* matches the search_key.
-            if search_key.startswith(map_key):
-                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
-                new_state_dict[new_key] = value
-                converted_count += 1
-            else:
-                new_state_dict[full_key] = value
-                not_converted_count += 1
-        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
-            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
-            new_state_dict[full_key] = value
-            continue
-        else:
-            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
-
-    if converted_count > 0 and not_converted_count > 0:
-        raise ValueError(
-            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
-            f" not_converted={not_converted_count}"
-        )
-
-    return new_state_dict
-
-
-# code from
-# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
-def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
-    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
-    unet_conversion_map_layer = []
+    unet_conversion_map_layer: list[tuple[str, str]] = []
 
     for i in range(3):  # num_blocks is 3 in sdxl
         # loop over downblocks/upblocks
@@ -126,7 +67,7 @@ def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
         ("skip_connection.", "conv_shortcut."),
     ]
 
-    unet_conversion_map = []
+    unet_conversion_map: list[tuple[str, str]] = []
     for sd, hf in unet_conversion_map_layer:
         if "resnets" in hf:
             for sd_res, hf_res in unet_conversion_map_resnet:
@@ -154,3 +95,61 @@ def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
 SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
     sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
 }
+
+
+def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+    """Convert the keys of an SDXL LoRA state_dict to diffusers format.
+
+    The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
+    diffusers format, then this function will have no effect.
+
+    This function is adapted from:
+    https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
+
+    Args:
+        state_dict (dict[str, Tensor]): The SDXL LoRA state_dict.
+
+    Raises:
+        ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
+
+    Returns:
+        dict[str, Tensor]: The diffusers-format state_dict.
+    """
+    converted_count = 0  # The number of Stability AI keys converted to diffusers format.
+    not_converted_count = 0  # The number of keys that were not converted.
+
+    # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching prefixes.
+    # For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
+    # `input_blocks_4_1_proj_in`.
+    stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
+    stability_unet_keys.sort()
+
+    new_state_dict: dict[str, torch.Tensor] = {}
+    for full_key, value in state_dict.items():
+        if full_key.startswith("lora_unet_"):
+            search_key = full_key.replace("lora_unet_", "")
+            # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
+            position = bisect.bisect_right(stability_unet_keys, search_key)
+            map_key = stability_unet_keys[position - 1]
+            # Now, check if the map_key *actually* matches the search_key.
+            if search_key.startswith(map_key):
+                new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
+                new_state_dict[new_key] = value
+                converted_count += 1
+            else:
+                new_state_dict[full_key] = value
+                not_converted_count += 1
+        elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+            # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
+            new_state_dict[full_key] = value
+            continue
+        else:
+            raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
+
+    if converted_count > 0 and not_converted_count > 0:
+        raise ValueError(
+            f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
+            f" not_converted={not_converted_count}"
+        )
+
+    return new_state_dict
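
A minimal usage sketch of the converter touched by this diff (not part of the patch itself), assuming the module path from the diff header; the LoRA key names and tensor shapes below are illustrative placeholders:

import torch

from invokeai.backend.lora.sdxl_state_dict_utils import convert_sdxl_keys_to_diffusers_format

# Hypothetical Stability-AI-format SDXL LoRA state_dict: one UNet key and one CLIP text-encoder key.
state_dict = {
    "lora_unet_input_blocks_4_1_proj_in.lora_down.weight": torch.zeros(8, 640),
    "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight": torch.zeros(8, 768),
}

converted = convert_sdxl_keys_to_diffusers_format(state_dict)

# The UNet key is rewritten to its diffusers block name (a "lora_unet_down_blocks_..." prefix),
# while the "lora_te1_"/"lora_te2_" key passes through unchanged.
print(sorted(converted))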