Ruff formatting

This commit is contained in:
Billy
2025-06-26 17:29:25 +10:00
committed by psychedelicious
parent 3cd4306eec
commit 6afbf31750
9 changed files with 125 additions and 97 deletions

View File

@ -37,7 +37,7 @@ from invokeai.app.util.misc import uuid_string
from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.backend.model_manager.model_on_disk import ModelOnDisk
from invokeai.backend.model_manager.omi import stable_diffusion_xl_1_lora, flux_dev_1_lora
from invokeai.backend.model_manager.omi import flux_dev_1_lora, stable_diffusion_xl_1_lora
from invokeai.backend.model_manager.taxonomy import (
AnyVariant,
BaseModelType,

View File

@ -1,6 +1,7 @@
from invokeai.backend.model_manager.omi.omi import convert_from_omi
from invokeai.backend.model_manager.omi.vendor.model_spec.architecture import (
flux_dev_1_lora,
stable_diffusion_xl_1_lora,
flux_dev_1_lora
)
from invokeai.backend.model_manager.omi.omi import convert_from_omi
__all__ = ["flux_dev_1_lora", "stable_diffusion_xl_1_lora", "convert_from_omi"]

View File

@ -1,10 +1,13 @@
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_flux_lora as omi_flux,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_lora_util as lora_util,
convert_flux_lora as omi_flux,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora import (
convert_sdxl_lora as omi_sdxl,
)
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.taxonomy import BaseModelType

View File

@ -1,4 +1,7 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_clip(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:

View File

@ -1,5 +1,8 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_t5 import map_t5
@ -45,10 +48,16 @@ def __map_transformer(key_prefix: LoraConversionKeySet) -> list[LoraConversionKe
keys = []
keys += [LoraConversionKeySet("txt_in", "context_embedder", parent=key_prefix)]
keys += [LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)]
keys += [
LoraConversionKeySet("final_layer.adaLN_modulation.1", "norm_out.linear", parent=key_prefix, swap_chunks=True)
]
keys += [LoraConversionKeySet("final_layer.linear", "proj_out", parent=key_prefix)]
keys += [LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)]
keys += [
LoraConversionKeySet("guidance_in.in_layer", "time_text_embed.guidance_embedder.linear_1", parent=key_prefix)
]
keys += [
LoraConversionKeySet("guidance_in.out_layer", "time_text_embed.guidance_embedder.linear_2", parent=key_prefix)
]
keys += [LoraConversionKeySet("vector_in.in_layer", "time_text_embed.text_embedder.linear_1", parent=key_prefix)]
keys += [LoraConversionKeySet("vector_in.out_layer", "time_text_embed.text_embedder.linear_2", parent=key_prefix)]
keys += [LoraConversionKeySet("time_in.in_layer", "time_text_embed.timestep_embedder.linear_1", parent=key_prefix)]

View File

@ -1,20 +1,19 @@
import torch
from torch import Tensor
from typing_extensions import Self
class LoraConversionKeySet:
def __init__(
self,
omi_prefix: str,
diffusers_prefix: str,
legacy_diffusers_prefix: str | None = None,
parent: Self | None = None,
swap_chunks: bool = False,
filter_is_last: bool | None = None,
next_omi_prefix: str | None = None,
next_diffusers_prefix: str | None = None,
self,
omi_prefix: str,
diffusers_prefix: str,
legacy_diffusers_prefix: str | None = None,
parent: Self | None = None,
swap_chunks: bool = False,
filter_is_last: bool | None = None,
next_omi_prefix: str | None = None,
next_diffusers_prefix: str | None = None,
):
if parent is not None:
self.omi_prefix = combine(parent.omi_prefix, omi_prefix)
@ -24,9 +23,11 @@ class LoraConversionKeySet:
self.diffusers_prefix = diffusers_prefix
if legacy_diffusers_prefix is None:
self.legacy_diffusers_prefix = self.diffusers_prefix.replace('.', '_')
self.legacy_diffusers_prefix = self.diffusers_prefix.replace(".", "_")
elif parent is not None:
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace('.', '_')
self.legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, legacy_diffusers_prefix).replace(
".", "_"
)
else:
self.legacy_diffusers_prefix = legacy_diffusers_prefix
@ -42,11 +43,13 @@ class LoraConversionKeySet:
elif next_omi_prefix is not None and parent is not None:
self.next_omi_prefix = combine(parent.omi_prefix, next_omi_prefix)
self.next_diffusers_prefix = combine(parent.diffusers_prefix, next_diffusers_prefix)
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace('.', '_')
self.next_legacy_diffusers_prefix = combine(parent.legacy_diffusers_prefix, next_diffusers_prefix).replace(
".", "_"
)
elif next_omi_prefix is not None and parent is None:
self.next_omi_prefix = next_omi_prefix
self.next_diffusers_prefix = next_diffusers_prefix
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace('.', '_')
self.next_legacy_diffusers_prefix = next_diffusers_prefix.replace(".", "_")
else:
self.next_omi_prefix = None
self.next_diffusers_prefix = None
@ -61,19 +64,19 @@ class LoraConversionKeySet:
def __get_legacy_diffusers(self, in_prefix: str, key: str) -> str:
key = self.legacy_diffusers_prefix + key.removeprefix(in_prefix)
suffix = key[key.rfind('.'):]
if suffix not in ['.alpha', '.dora_scale']: # some keys only have a single . in the suffix
suffix = key[key.removesuffix(suffix).rfind('.'):]
suffix = key[key.rfind(".") :]
if suffix not in [".alpha", ".dora_scale"]: # some keys only have a single . in the suffix
suffix = key[key.removesuffix(suffix).rfind(".") :]
key = key.removesuffix(suffix)
return key.replace('.', '_') + suffix
return key.replace(".", "_") + suffix
def get_key(self, in_prefix: str, key: str, target: str) -> str:
if target == 'omi':
if target == "omi":
return self.__get_omi(in_prefix, key)
elif target == 'diffusers':
elif target == "diffusers":
return self.__get_diffusers(in_prefix, key)
elif target == 'legacy_diffusers':
elif target == "legacy_diffusers":
return self.__get_legacy_diffusers(in_prefix, key)
return key
@ -82,8 +85,8 @@ class LoraConversionKeySet:
def combine(left: str, right: str) -> str:
left = left.rstrip('.')
right = right.lstrip('.')
left = left.rstrip(".")
right = right.lstrip(".")
if left == "" or left is None:
return right
elif right == "" or right is None:
@ -93,25 +96,28 @@ def combine(left: str, right: str) -> str:
def map_prefix_range(
omi_prefix: str,
diffusers_prefix: str,
parent: LoraConversionKeySet,
omi_prefix: str,
diffusers_prefix: str,
parent: LoraConversionKeySet,
) -> list[LoraConversionKeySet]:
# 100 should be a safe upper bound. increase if it's not enough in the future
return [LoraConversionKeySet(
omi_prefix=f"{omi_prefix}.{i}",
diffusers_prefix=f"{diffusers_prefix}.{i}",
parent=parent,
next_omi_prefix=f"{omi_prefix}.{i + 1}",
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
) for i in range(100)]
return [
LoraConversionKeySet(
omi_prefix=f"{omi_prefix}.{i}",
diffusers_prefix=f"{diffusers_prefix}.{i}",
parent=parent,
next_omi_prefix=f"{omi_prefix}.{i + 1}",
next_diffusers_prefix=f"{diffusers_prefix}.{i + 1}",
)
for i in range(100)
]
def __convert(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
source: str,
target: str,
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
source: str,
target: str,
) -> dict[str, Tensor]:
out_states = {}
@ -121,13 +127,13 @@ def __convert(
# TODO: maybe replace with a non O(n^2) algorithm
for key, tensor in state_dict.items():
for key_set in key_sets:
in_prefix = ''
in_prefix = ""
if source == 'omi':
if source == "omi":
in_prefix = key_set.omi_prefix
elif source == 'diffusers':
elif source == "diffusers":
in_prefix = key_set.diffusers_prefix
elif source == 'legacy_diffusers':
elif source == "legacy_diffusers":
in_prefix = key_set.legacy_diffusers_prefix
if not key.startswith(in_prefix):
@ -135,11 +141,11 @@ def __convert(
if key_set.filter_is_last is not None:
next_prefix = None
if source == 'omi':
if source == "omi":
next_prefix = key_set.next_omi_prefix
elif source == 'diffusers':
elif source == "diffusers":
next_prefix = key_set.next_diffusers_prefix
elif source == 'legacy_diffusers':
elif source == "legacy_diffusers":
next_prefix = key_set.next_legacy_diffusers_prefix
is_last = not any(k.startswith(next_prefix) for k in state_dict)
@ -148,8 +154,8 @@ def __convert(
name = key_set.get_key(in_prefix, key, target)
can_swap_chunks = target == 'omi' or source == 'omi'
if key_set.swap_chunks and name.endswith('.lora_up.weight') and can_swap_chunks:
can_swap_chunks = target == "omi" or source == "omi"
if key_set.swap_chunks and name.endswith(".lora_up.weight") and can_swap_chunks:
chunk_0, chunk_1 = tensor.chunk(2, dim=0)
tensor = torch.cat([chunk_1, chunk_0], dim=0)
@ -161,8 +167,8 @@ def __convert(
def __detect_source(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> str:
omi_count = 0
diffusers_count = 0
@ -178,34 +184,34 @@ def __detect_source(
legacy_diffusers_count += 1
if omi_count > diffusers_count and omi_count > legacy_diffusers_count:
return 'omi'
return "omi"
if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count:
return 'diffusers'
return "diffusers"
if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count:
return 'legacy_diffusers'
return "legacy_diffusers"
return ''
return ""
def convert_to_omi(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, 'omi')
return __convert(state_dict, key_sets, source, "omi")
def convert_to_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, 'diffusers')
return __convert(state_dict, key_sets, source, "diffusers")
def convert_to_legacy_diffusers(
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
state_dict: dict[str, Tensor],
key_sets: list[LoraConversionKeySet],
) -> dict[str, Tensor]:
source = __detect_source(state_dict, key_sets)
return __convert(state_dict, key_sets, source, 'legacy_diffusers')
return __convert(state_dict, key_sets, source, "legacy_diffusers")

View File

@ -1,5 +1,8 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_clip import map_clip
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def __map_unet_resnet_block(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:
@ -115,7 +118,7 @@ def convert_sdxl_lora_key_sets() -> list[LoraConversionKeySet]:
keys = []
keys += [LoraConversionKeySet("bundle_emb", "bundle_emb")]
keys += __map_unet(LoraConversionKeySet( "unet", "lora_unet"))
keys += __map_unet(LoraConversionKeySet("unet", "lora_unet"))
keys += map_clip(LoraConversionKeySet("clip_l", "lora_te1"))
keys += map_clip(LoraConversionKeySet("clip_g", "lora_te2"))

View File

@ -1,4 +1,7 @@
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import LoraConversionKeySet, map_prefix_range
from invokeai.backend.model_manager.omi.vendor.convert.lora.convert_lora_util import (
LoraConversionKeySet,
map_prefix_range,
)
def map_t5(key_prefix: LoraConversionKeySet) -> list[LoraConversionKeySet]:

View File

@ -1,31 +1,31 @@
stable_diffusion_1_lora = 'stable-diffusion-v1/lora'
stable_diffusion_1_inpainting_lora = 'stable-diffusion-v1-inpainting/lora'
stable_diffusion_1_lora = "stable-diffusion-v1/lora"
stable_diffusion_1_inpainting_lora = "stable-diffusion-v1-inpainting/lora"
stable_diffusion_2_512_lora = 'stable-diffusion-v2-512/lora'
stable_diffusion_2_768_v_lora = 'stable-diffusion-v2-768-v/lora'
stable_diffusion_2_depth_lora = 'stable-diffusion-v2-depth/lora'
stable_diffusion_2_inpainting_lora = 'stable-diffusion-v2-inpainting/lora'
stable_diffusion_2_512_lora = "stable-diffusion-v2-512/lora"
stable_diffusion_2_768_v_lora = "stable-diffusion-v2-768-v/lora"
stable_diffusion_2_depth_lora = "stable-diffusion-v2-depth/lora"
stable_diffusion_2_inpainting_lora = "stable-diffusion-v2-inpainting/lora"
stable_diffusion_3_medium_lora = 'stable-diffusion-v3-medium/lora'
stable_diffusion_35_medium_lora = 'stable-diffusion-v3.5-medium/lora'
stable_diffusion_35_large_lora = 'stable-diffusion-v3.5-large/lora'
stable_diffusion_3_medium_lora = "stable-diffusion-v3-medium/lora"
stable_diffusion_35_medium_lora = "stable-diffusion-v3.5-medium/lora"
stable_diffusion_35_large_lora = "stable-diffusion-v3.5-large/lora"
stable_diffusion_xl_1_lora = 'stable-diffusion-xl-v1-base/lora'
stable_diffusion_xl_1_inpainting_lora = 'stable-diffusion-xl-v1-base-inpainting/lora'
stable_diffusion_xl_1_lora = "stable-diffusion-xl-v1-base/lora"
stable_diffusion_xl_1_inpainting_lora = "stable-diffusion-xl-v1-base-inpainting/lora"
wuerstchen_2_lora = 'wuerstchen-v2-prior/lora'
stable_cascade_1_stage_a_lora = 'stable-cascade-v1-stage-a/lora'
stable_cascade_1_stage_b_lora = 'stable-cascade-v1-stage-b/lora'
stable_cascade_1_stage_c_lora = 'stable-cascade-v1-stage-c/lora'
wuerstchen_2_lora = "wuerstchen-v2-prior/lora"
stable_cascade_1_stage_a_lora = "stable-cascade-v1-stage-a/lora"
stable_cascade_1_stage_b_lora = "stable-cascade-v1-stage-b/lora"
stable_cascade_1_stage_c_lora = "stable-cascade-v1-stage-c/lora"
pixart_alpha_lora = 'pixart-alpha/lora'
pixart_sigma_lora = 'pixart-sigma/lora'
pixart_alpha_lora = "pixart-alpha/lora"
pixart_sigma_lora = "pixart-sigma/lora"
flux_dev_1_lora = 'Flux.1-dev/lora'
flux_fill_dev_1_lora = 'Flux.1-fill-dev/lora'
flux_dev_1_lora = "Flux.1-dev/lora"
flux_fill_dev_1_lora = "Flux.1-fill-dev/lora"
sana_lora = 'sana/lora'
sana_lora = "sana/lora"
hunyuan_video_lora = 'hunyuan-video/lora'
hunyuan_video_lora = "hunyuan-video/lora"
hi_dream_i1_lora = 'hidream-i1/lora'
hi_dream_i1_lora = "hidream-i1/lora"