mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
(minor) tidy types in sdxl_format_utils.py
This commit is contained in:
parent
24950dea8c
commit
bfd5cdb311
@ -1,14 +1,15 @@
|
|||||||
import bisect
|
import bisect
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
# code from
|
def make_sdxl_unet_conversion_map() -> list[tuple[str, str]]:
|
||||||
# https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format.
|
||||||
def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
|
||||||
"""Create a dict mapping state_dict keys from Stability AI SDXL format to diffusers SDXL format."""
|
Ported from:
|
||||||
unet_conversion_map_layer = []
|
https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L15C1-L97C32
|
||||||
|
"""
|
||||||
|
unet_conversion_map_layer: list[tuple[str, str]] = []
|
||||||
|
|
||||||
for i in range(3): # num_blocks is 3 in sdxl
|
for i in range(3): # num_blocks is 3 in sdxl
|
||||||
# loop over downblocks/upblocks
|
# loop over downblocks/upblocks
|
||||||
@ -66,7 +67,7 @@ def make_sdxl_unet_conversion_map() -> List[Tuple[str, str]]:
|
|||||||
("skip_connection.", "conv_shortcut."),
|
("skip_connection.", "conv_shortcut."),
|
||||||
]
|
]
|
||||||
|
|
||||||
unet_conversion_map = []
|
unet_conversion_map: list[tuple[str, str]] = []
|
||||||
for sd, hf in unet_conversion_map_layer:
|
for sd, hf in unet_conversion_map_layer:
|
||||||
if "resnets" in hf:
|
if "resnets" in hf:
|
||||||
for sd_res, hf_res in unet_conversion_map_resnet:
|
for sd_res, hf_res in unet_conversion_map_resnet:
|
||||||
@ -96,8 +97,7 @@ SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
def convert_sdxl_keys_to_diffusers_format(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||||
def convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
||||||
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
"""Convert the keys of an SDXL LoRA state_dict to diffusers format.
|
||||||
|
|
||||||
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
The input state_dict can be in either Stability AI format or diffusers format. If the state_dict is already in
|
||||||
@ -124,7 +124,7 @@ def convert_sdxl_keys_to_diffusers_format(cls, state_dict: Dict[str, torch.Tenso
|
|||||||
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
stability_unet_keys = list(SDXL_UNET_STABILITY_TO_DIFFUSERS_MAP)
|
||||||
stability_unet_keys.sort()
|
stability_unet_keys.sort()
|
||||||
|
|
||||||
new_state_dict = {}
|
new_state_dict: dict[str, torch.Tensor] = {}
|
||||||
for full_key, value in state_dict.items():
|
for full_key, value in state_dict.items():
|
||||||
if full_key.startswith("lora_unet_"):
|
if full_key.startswith("lora_unet_"):
|
||||||
search_key = full_key.replace("lora_unet_", "")
|
search_key = full_key.replace("lora_unet_", "")
|
||||||
|
Loading…
Reference in New Issue
Block a user