Ignore bundled embeddings in conversion

Author: Kent Keirsey
Date:   2025-06-23 10:05:55 -04:00
Parent: 61b049ad35
Commit: 77e029a49f


@@ -10,6 +10,8 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str,
     The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
     diffusers format, then this function will have no effect.
 
+    Keys that start with "bundle_emb" will be dropped/ignored from the output state_dict.
+
     This function is adapted from:
     https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L385-L409
@@ -20,7 +22,7 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str,
         ValueError: If state_dict contains an unrecognized key, or not all keys could be converted.
 
     Returns:
-        Dict[str, Tensor]: The diffusers-format state_dict.
+        Dict[str, Tensor]: The diffusers-format state_dict with bundle_emb keys removed.
     """
     converted_count = 0  # The number of Stability AI keys converted to diffusers format.
     not_converted_count = 0  # The number of keys that were not converted.
@@ -33,6 +35,10 @@ def convert_sdxl_keys_to_diffusers_format(state_dict: Dict[str, T]) -> dict[str,
     new_state_dict: dict[str, T] = {}
     for full_key, value in state_dict.items():
+        # Skip keys that start with "bundle_emb"
+        if full_key.startswith("bundle_emb"):
+            continue
+
         if full_key.startswith("lora_unet_"):
             search_key = full_key.replace("lora_unet_", "")
             # Use bisect to find the key in stability_unet_keys that *may* match the search_key's prefix.
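The bisect comment above refers to a standard prefix-lookup idiom over a sorted key list. Here is a standalone sketch of that technique; the names find_prefix and sorted_prefixes are illustrative and do not appear in the file, and it assumes no list entry is a prefix of another entry:

    import bisect

    def find_prefix(sorted_prefixes: list[str], search_key: str) -> str | None:
        """Return the entry of sorted_prefixes that is a prefix of search_key, if any.

        bisect_right finds the insertion point for search_key; the entry just to
        its left is the largest entry <= search_key, which is the only candidate
        that can be a prefix of it (given that no entry is a prefix of another).
        """
        position = bisect.bisect_right(sorted_prefixes, search_key)
        if position == 0:
            return None  # Every entry sorts after search_key; no match possible.
        candidate = sorted_prefixes[position - 1]
        return candidate if search_key.startswith(candidate) else None

    # Example: map a flattened UNet key to its matching prefix.
    prefixes = sorted(["down_blocks_0_attentions_0", "mid_block_attentions_0"])
    print(find_prefix(prefixes, "mid_block_attentions_0_proj_in"))
    # -> "mid_block_attentions_0"

This gives an O(log n) lookup per key instead of scanning every known prefix, which matters when converting state dicts with hundreds of keys.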