Add support for loading SDXL LoRA weights in diffusers format.

Ryan Dick 2023-08-11 11:35:47 -04:00 committed by Kent Keirsey
parent 824ca92760
commit 6a8eb392b2


@@ -1,18 +1,21 @@
+import bisect
 import os
-import torch
 from enum import Enum
-from typing import Optional, Dict, Union, Literal, Any
 from pathlib import Path
+from typing import Dict, Optional, Union
+
+import torch
 from safetensors.torch import load_file
+
 from .base import (
+    BaseModelType,
+    InvalidModelException,
     ModelBase,
     ModelConfigBase,
-    BaseModelType,
+    ModelNotFoundException,
     ModelType,
     SubModelType,
     classproperty,
-    InvalidModelException,
-    ModelNotFoundException,
 )
@@ -482,30 +485,61 @@ class LoRAModelRaw:  # (torch.nn.Module):
         return model_size

     @classmethod
-    def _convert_sdxl_compvis_keys(cls, state_dict):
+    def _convert_sdxl_keys_to_diffusers_format(cls, state_dict):
+        """Convert the keys of an SDXL LoRA state_dict to diffusers format.
+
+        The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
+        diffusers format, then this function will have no effect.
+
+        This function is adapted from:
+        https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
+
+        Args:
+            state_dict (Dict[str, Tensor]): The SDXL LoRA state_dict.
+
+        Raises:
+            ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
+
+        Returns:
+            Dict[str, Tensor]: The diffusers-format state_dict.
+        """
+        converted_count = 0  # The number of Stability AI keys converted to diffusers format.
+        not_converted_count = 0  # The number of keys that were not converted.
+
+        # Get a sorted list of Stability AI UNet keys so that we can efficiently search for keys with matching
+        # prefixes. For example, we want to efficiently find `input_blocks_4_1` in the list when searching for
+        # `input_blocks_4_1_proj_in`.
+        stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
+        stability_unet_keys.sort()
+
         new_state_dict = dict()
         for full_key, value in state_dict.items():
-            if full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
-                continue  # clip same
-
-            if not full_key.startswith("lora_unet_"):
-                raise NotImplementedError(f"Unknown prefix for sdxl lora key - {full_key}")
-            src_key = full_key.replace("lora_unet_", "")
-            try:
-                dst_key = None
-                while "_" in src_key:
-                    if src_key in SDXL_UNET_COMPVIS_MAP:
-                        dst_key = SDXL_UNET_COMPVIS_MAP[src_key]
-                        break
-                    src_key = "_".join(src_key.split("_")[:-1])
-
-                if dst_key is None:
-                    raise Exception(f"Unknown sdxl lora key - {full_key}")
-                new_key = full_key.replace(src_key, dst_key)
-            except:
-                print(SDXL_UNET_COMPVIS_MAP)
-                raise
-
-            new_state_dict[new_key] = value
+            if full_key.startswith("lora_unet_"):
+                search_key = full_key.replace("lora_unet_", "")
+                # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
+                position = bisect.bisect_right(stability_unet_keys, search_key)
+                map_key = stability_unet_keys[position - 1]
+                # Now, check if the map_key *actually* matches the search_key.
+                if search_key.startswith(map_key):
+                    new_key = full_key.replace(map_key, SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP[map_key])
+                    new_state_dict[new_key] = value
+                    converted_count += 1
+                else:
+                    new_state_dict[full_key] = value
+                    not_converted_count += 1
+            elif full_key.startswith("lora_te1_") or full_key.startswith("lora_te2_"):
+                # The CLIP text encoders have the same keys in both Stability AI and diffusers formats.
+                new_state_dict[full_key] = value
+                continue
+            else:
+                raise ValueError(f"Unrecognized SDXL LoRA key prefix: '{full_key}'.")
+
+        if converted_count > 0 and not_converted_count > 0:
+            raise ValueError(
+                f"The SDXL LoRA could only be partially converted to diffusers format. converted={converted_count},"
+                f" not_converted={not_converted_count}"
+            )
+
         return new_state_dict

     @classmethod
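
The conversion above relies on a single bisect_right lookup over the sorted Stability AI key list: in a sorted list, the entry immediately to the left of the insertion point is the only candidate that can be a prefix of the key being searched. A minimal, self-contained sketch of that prefix-lookup idea follows; PREFIX_MAP and convert_key are illustrative stand-ins, not identifiers from this commit:

import bisect

# Toy stand-in for SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP (a small, hand-picked subset of entries).
PREFIX_MAP = {
    "input_blocks_4_1": "down_blocks_1_attentions_0",
    "middle_block_1": "mid_block_attentions_0",
    "output_blocks_2_0": "up_blocks_0_resnets_2",
}
sorted_prefixes = sorted(PREFIX_MAP)

def convert_key(key: str) -> str:
    """Swap a known Stability AI prefix for its diffusers equivalent, if one matches."""
    # The entry just before the insertion point is the only candidate prefix of `key`.
    position = bisect.bisect_right(sorted_prefixes, key)
    candidate = sorted_prefixes[position - 1]
    if key.startswith(candidate):
        return key.replace(candidate, PREFIX_MAP[candidate])
    return key  # not a prefix we know how to convert; leave the key unchanged

assert convert_key("input_blocks_4_1_proj_in") == "down_blocks_1_attentions_0_proj_in"
assert convert_key("already_diffusers_key") == "already_diffusers_key"

This sketch mirrors the bisect logic in _convert_sdxl_keys_to_diffusers_format, minus the lora_unet_/lora_te1_/lora_te2_ prefix dispatch and the partial-conversion check.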
@@ -537,7 +571,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
         state_dict = cls._group_state(state_dict)

         if base_model == BaseModelType.StableDiffusionXL:
-            state_dict = cls._convert_sdxl_compvis_keys(state_dict)
+            state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

         for layer_key, values in state_dict.items():
             # lora and locon
@@ -588,6 +622,7 @@ class LoRAModelRaw:  # (torch.nn.Module):
 # code from
 # https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
 def make_sdxl_unet_conversion_map():
+    """Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
     unet_conversion_map_layer = []

     for i in range(3):  # num_blocks is 3 in sdxl
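
For context, the linked kohya_ss helper builds the conversion map as a list of (Stability AI prefix, diffusers prefix) pairs by enumerating the SDXL UNet blocks. A heavily abridged sketch of that pattern, covering only the down-block attention layers (the real function also covers resnets, the mid block, the up blocks, and the up/down samplers, among other layers):

def make_toy_sdxl_attention_map():
    """Illustrative miniature of a kohya_ss-style conversion map: (Stability AI prefix, diffusers prefix) pairs."""
    conversion_map = []
    for i in range(3):  # num_blocks is 3 in sdxl
        for j in range(2):
            if i > 0:  # the first SDXL down block has no attention layers
                sd_prefix = f"input_blocks.{3 * i + j + 1}.1."
                hf_prefix = f"down_blocks.{i}.attentions.{j}."
                conversion_map.append((sd_prefix, hf_prefix))
    return conversion_map

# e.g. ('input_blocks.4.1.', 'down_blocks.1.attentions.0.') is among the returned pairs.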
@@ -671,7 +706,6 @@ def make_sdxl_unet_conversion_map():
     return unet_conversion_map


-SDXL_UNET_COMPVIS_MAP = {
-    f"{sd}".rstrip(".").replace(".", "_"): f"{hf}".rstrip(".").replace(".", "_")
-    for sd, hf in make_sdxl_unet_conversion_map()
+SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
+    sd.rstrip(".").replace(".", "_"): hf.rstrip(".").replace(".", "_") for sd, hf in make_sdxl_unet_conversion_map()
 }
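
The comprehension above only reshapes the dotted prefixes produced by make_sdxl_unet_conversion_map() into the underscore-separated form that LoRA state_dict keys use. A small sketch of that normalization on one representative pair (the specific pair is shown for illustration):

# A (Stability AI, diffusers) prefix pair of the kind produced by make_sdxl_unet_conversion_map().
sd, hf = "input_blocks.4.1.", "down_blocks.1.attentions.0."

# Same transform as the dict comprehension: drop the trailing dot, then replace the
# remaining dots with underscores so the prefixes line up with "lora_unet_*" key naming.
key, value = sd.rstrip(".").replace(".", "_"), hf.rstrip(".").replace(".", "_")
print(key, "->", value)  # input_blocks_4_1 -> down_blocks_1_attentions_0

The resulting dict is what _convert_sdxl_keys_to_diffusers_format searches with bisect, as shown earlier.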