test: add some aitoolkit lora tests

Kevin Turner
2025-05-31 16:02:57 -07:00
committed by psychedelicious
parent b08f90c99f
commit 2981591c36
5 changed files with 548 additions and 18 deletions

View File

@@ -21,7 +21,7 @@ from invokeai.backend.model_manager.taxonomy import (
     SubModelType,
 )
 from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
-    is_state_dict_likely_in_aitoolkit_format,
+    is_state_dict_likely_in_flux_aitoolkit_format,
     lora_model_from_flux_aitoolkit_state_dict,
 )
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
@@ -96,7 +96,7 @@ class LoRALoader(ModelLoader):
                 model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
             elif is_state_dict_likely_flux_control(state_dict=state_dict):
                 model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
-            elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
+            elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
                 model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
             else:
                 raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
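For context on the detector this loader now calls: the `metadata` argument accepted by `is_state_dict_likely_in_flux_aitoolkit_format` is expected to be the safetensors header metadata, where ai-toolkit records a JSON "software" entry. A minimal sketch of how a caller could obtain it; the path and variable names are illustrative, not part of this commit:

import json
from safetensors import safe_open

lora_path = "flux_aitoolkit_lora.safetensors"  # illustrative path, not from this commit
with safe_open(lora_path, framework="pt", device="cpu") as f:
    # ai-toolkit files are expected to carry {"software": '{"name": "ai-toolkit", ...}'} here.
    metadata = f.metadata() or {}
    state_dict = {key: f.get_tensor(key) for key in f.keys()}

# The detection logic checks json.loads(metadata["software"])["name"] == "ai-toolkit",
# and falls back to looking for a top-level "diffusion_model." key prefix when metadata is missing.
print(json.loads(metadata.get("software", "{}")).get("name"))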

View File

@@ -1,5 +1,6 @@
 import json
 from collections import defaultdict
+from dataclasses import dataclass, field
 from typing import Any
 
 import torch
@@ -8,33 +9,44 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.util import InvokeAILogger
 
 
-def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
+def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
     if metadata:
-        software = json.loads(metadata.get("software", "{}"))
+        try:
+            software = json.loads(metadata.get("software", "{}"))
+        except json.JSONDecodeError:
+            return False
         return software.get("name") == "ai-toolkit"
     # metadata got lost somewhere
     return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
 
 
-def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
-    # Group keys by layer.
-    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
+@dataclass
+class GroupedStateDict:
+    transformer: dict = field(default_factory=dict)
+    # might also grow CLIP and T5 submodels
+
+
+def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStateDict:
+    logger = InvokeAILogger.get_logger()
+    grouped = GroupedStateDict()
     for key, value in state_dict.items():
-        layer_name, param_name = key.split(".", 1)
-        grouped_state_dict[layer_name][param_name] = value
+        submodel_name, param_name = key.split(".", 1)
+        match submodel_name:
+            case "diffusion_model":
+                grouped.transformer[param_name] = value
+            case _:
+                logger.warning(f"Unexpected submodel name: {submodel_name}")
+    return grouped
-
-    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
-
-    for layer_name, layer_state_dict in grouped_state_dict.items():
-        if layer_name.startswith("diffusion_model"):
-            transformer_grouped_sd[layer_name] = layer_state_dict
-        else:
-            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+
+def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
+    grouped = _group_state_by_submodel(state_dict)
 
     layers: dict[str, BaseLayerPatch] = {}
-    for layer_key, layer_state_dict in transformer_grouped_sd.items():
+    for layer_key, layer_state_dict in grouped.transformer.items():
         layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     return ModelPatchRaw(layers=layers)
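As a concrete illustration of the grouping introduced above: `_group_state_by_submodel` splits each key on the first ".", routes `diffusion_model.*` entries into the `transformer` bucket with that prefix stripped, and warns on anything else. A standalone sketch with toy tensors, independent of the Invoke classes and not part of the commit:

import torch

toy_state_dict = {
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": torch.zeros(16, 3072),
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": torch.zeros(9216, 16),
    "text_encoder.some.key": torch.zeros(1),  # would hit the "Unexpected submodel name" warning branch
}

transformer: dict[str, torch.Tensor] = {}
for key, value in toy_state_dict.items():
    submodel_name, param_name = key.split(".", 1)  # ("diffusion_model", "double_blocks.0....")
    if submodel_name == "diffusion_model":
        transformer[param_name] = value  # prefix stripped, as in _group_state_by_submodel

assert "double_blocks.0.img_attn.qkv.lora_A.weight" in transformer
# lora_model_from_flux_aitoolkit_state_dict then prepends FLUX_LORA_TRANSFORMER_PREFIX to these keys.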

View File

@@ -1,6 +1,6 @@
 from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
 from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
-    is_state_dict_likely_in_aitoolkit_format,
+    is_state_dict_likely_in_flux_aitoolkit_format,
 )
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
@@ -15,7 +15,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
 
 def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
-    if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
+    if is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
         return FluxLoRAFormat.AIToolkit
     if is_state_dict_likely_in_flux_kohya_format(state_dict):
         return FluxLoRAFormat.Kohya
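The rename also lands in this format-dispatch helper: the ai-toolkit check runs first, so a file whose safetensors metadata names ai-toolkit is classified before the other key-pattern checks are consulted. A minimal usage sketch with mock inputs, illustrative only:

state_dict = {"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": None}  # toy value, unused when metadata is present
metadata = {"software": '{"name": "ai-toolkit"}'}

assert flux_format_from_state_dict(state_dict, metadata) == FluxLoRAFormat.AIToolkit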

View File

@@ -0,0 +1,458 @@
state_dict_keys = {
"diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.0.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.0.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.0.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.0.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.0.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.0.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.1.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.1.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.1.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.1.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.1.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.1.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.1.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.1.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.10.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.10.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.10.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.10.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.10.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.10.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.10.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.10.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.11.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.11.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.11.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.11.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.11.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.11.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.11.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.11.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.12.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.12.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.12.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.12.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.12.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.12.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.12.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.12.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.13.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.13.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.13.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.13.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.13.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.13.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.13.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.13.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.14.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.14.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.14.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.14.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.14.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.14.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.14.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.14.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.15.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.15.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.15.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.15.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.15.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.15.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.15.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.15.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.16.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.16.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.16.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.16.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.16.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.16.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.16.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.16.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.17.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.17.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.17.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.17.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.17.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.17.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.17.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.17.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.18.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.18.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.18.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.18.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.18.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.18.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.18.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.18.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.2.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.2.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.2.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.2.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.2.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.2.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.2.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.2.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.3.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.3.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.3.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.3.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.3.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.3.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.3.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.3.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.4.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.4.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.4.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.4.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.4.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.4.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.4.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.4.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.5.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.5.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.5.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.5.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.5.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.5.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.5.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.5.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.6.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.6.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.6.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.6.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.6.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.6.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.6.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.6.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.7.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.7.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.7.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.7.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.7.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.7.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.7.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.7.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.8.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.8.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.8.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.8.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.8.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.8.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.8.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.8.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.img_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.img_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.9.img_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.img_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.9.img_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.9.img_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.txt_attn.proj.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_attn.proj.lora_B.weight": [3072, 16],
"diffusion_model.double_blocks.9.txt_attn.qkv.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_attn.qkv.lora_B.weight": [9216, 16],
"diffusion_model.double_blocks.9.txt_mlp.0.lora_A.weight": [16, 3072],
"diffusion_model.double_blocks.9.txt_mlp.0.lora_B.weight": [12288, 16],
"diffusion_model.double_blocks.9.txt_mlp.2.lora_A.weight": [16, 12288],
"diffusion_model.double_blocks.9.txt_mlp.2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.0.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.0.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.1.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.1.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.1.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.1.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.10.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.10.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.10.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.10.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.11.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.11.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.11.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.11.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.12.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.12.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.12.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.12.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.13.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.13.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.13.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.13.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.14.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.14.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.14.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.14.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.15.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.15.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.15.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.15.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.16.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.16.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.16.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.16.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.17.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.17.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.17.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.17.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.18.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.18.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.18.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.18.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.19.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.19.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.19.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.19.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.2.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.2.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.2.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.2.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.20.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.20.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.20.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.20.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.21.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.21.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.21.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.21.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.22.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.22.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.22.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.22.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.23.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.23.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.23.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.23.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.24.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.24.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.24.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.24.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.25.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.25.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.25.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.25.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.26.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.26.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.26.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.26.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.27.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.27.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.27.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.27.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.28.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.28.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.28.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.28.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.29.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.29.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.29.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.29.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.3.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.3.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.3.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.3.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.30.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.30.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.30.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.30.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.31.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.31.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.31.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.31.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.32.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.32.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.32.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.32.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.33.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.33.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.33.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.33.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.34.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.34.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.34.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.34.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.35.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.35.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.35.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.35.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.36.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.36.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.36.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.36.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.37.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.37.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.37.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.37.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.4.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.4.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.4.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.4.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.5.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.5.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.5.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.5.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.6.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.6.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.6.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.6.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.7.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.7.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.7.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.7.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.8.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.8.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.8.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.8.linear2.lora_B.weight": [3072, 16],
"diffusion_model.single_blocks.9.linear1.lora_A.weight": [16, 3072],
"diffusion_model.single_blocks.9.linear1.lora_B.weight": [21504, 16],
"diffusion_model.single_blocks.9.linear2.lora_A.weight": [16, 15360],
"diffusion_model.single_blocks.9.linear2.lora_B.weight": [3072, 16],
}
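This fixture records each ai-toolkit LoRA key together with its tensor shape rather than real weights. The helper `keys_to_mock_state_dict` used by the tests below is not part of this commit; presumably it materializes dummy tensors of these shapes, roughly along these lines (an assumption about the helper, not its actual source):

import torch

def keys_to_mock_state_dict(keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
    # Assumed behavior: one zero-filled tensor per key, with the listed shape.
    return {key: torch.zeros(shape) for key, shape in keys.items()}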

View File

@@ -0,0 +1,60 @@
import accelerate
import pytest

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
    _group_state_by_submodel,
    is_state_dict_likely_in_flux_aitoolkit_format,
    lora_model_from_flux_aitoolkit_state_dict,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
    state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import state_dict_keys as flux_aitoolkit_state_dict_keys
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_aitoolkit_format():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
    assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_aitoolkit_format_false(sd_keys: dict[str, list[int]]):
    state_dict = keys_to_mock_state_dict(sd_keys)
    assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict)


def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
    converted_state_dict = _group_state_by_submodel(state_dict).transformer

    # Extract the prefixes from the converted state dict (without the lora suffixes).
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        k = k.replace(".lora_A.weight", "")
        k = k.replace(".lora_B.weight", "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with accelerate.init_empty_weights():
        model = Flux(params["flux-schnell"])
    model_keys = set(model.state_dict().keys())

    for converted_key_prefix in converted_key_prefixes:
        assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), f"'{converted_key_prefix}' did not match any model keys."


def test_lora_model_from_flux_aitoolkit_state_dict():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)

    # Smoke test: the conversion should run without raising.
    lora_model_from_flux_aitoolkit_state_dict(state_dict)
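Assuming the new test module lives alongside the other LoRA conversion tests (inferred from the `tests.backend.patches.lora_conversions...` imports above, not stated in this commit), it can be run on its own:

import pytest

# Hypothetical path, inferred from the import layout above.
pytest.main(["-q", "tests/backend/patches/lora_conversions/test_flux_aitoolkit_lora_conversion_utils.py"])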