Mirror of https://github.com/invoke-ai/InvokeAI (synced 2025-07-26 05:17:55 +00:00)

test: add some aitoolkit lora tests

Committed by: psychedelicious
Parent: b08f90c99f
Commit: 2981591c36
@@ -21,7 +21,7 @@ from invokeai.backend.model_manager.taxonomy import (
     SubModelType,
 )
 from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
-    is_state_dict_likely_in_aitoolkit_format,
+    is_state_dict_likely_in_flux_aitoolkit_format,
     lora_model_from_flux_aitoolkit_state_dict,
 )
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
@@ -96,7 +96,7 @@ class LoRALoader(ModelLoader):
             model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
         elif is_state_dict_likely_flux_control(state_dict=state_dict):
             model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
-        elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
+        elif is_state_dict_likely_in_flux_aitoolkit_format(state_dict=state_dict):
             model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
         else:
             raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
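The renamed predicate only has to tell AI-Toolkit exports apart from the other FLUX LoRA flavours handled above. A minimal standalone sketch of the key-root heuristic the fallback path relies on (plain Python, no InvokeAI imports; the Kohya-style key below is an illustrative assumption, not taken from this commit):

    from typing import Any


    def looks_like_flux_aitoolkit(state_dict: dict[str, Any]) -> bool:
        # Mirrors the fallback heuristic in the diff: AI-Toolkit FLUX LoRA keys are
        # rooted at "diffusion_model", e.g. "diffusion_model.double_blocks.0...".
        return any(k.split(".", 1)[0] == "diffusion_model" for k in state_dict)


    assert looks_like_flux_aitoolkit({"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": None})
    # Assumed Kohya-style key, shown only for contrast:
    assert not looks_like_flux_aitoolkit({"lora_unet_double_blocks_0_attn_qkv.lora_down.weight": None})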
@@ -1,5 +1,6 @@
 import json
 from collections import defaultdict
+from dataclasses import dataclass, field
 from typing import Any
 
 import torch
@@ -8,33 +9,44 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
+from invokeai.backend.util import InvokeAILogger
 
 
-def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
+def is_state_dict_likely_in_flux_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
     if metadata:
-        software = json.loads(metadata.get("software", "{}"))
+        try:
+            software = json.loads(metadata.get("software", "{}"))
+        except json.JSONDecodeError:
+            return False
         return software.get("name") == "ai-toolkit"
     # metadata got lost somewhere
     return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
 
 
-def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
-    # Group keys by layer.
-    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
+@dataclass
+class GroupedStateDict:
+    transformer: dict = field(default_factory=dict)
+    # might also grow CLIP and T5 submodels
+
+
+def _group_state_by_submodel(state_dict: dict[str, torch.Tensor]) -> GroupedStateDict:
+    logger = InvokeAILogger.get_logger()
+    grouped = GroupedStateDict()
     for key, value in state_dict.items():
-        layer_name, param_name = key.split(".", 1)
-        grouped_state_dict[layer_name][param_name] = value
+        submodel_name, param_name = key.split(".", 1)
+        match submodel_name:
+            case "diffusion_model":
+                grouped.transformer[param_name] = value
+            case _:
+                logger.warning(f"Unexpected submodel name: {submodel_name}")
+    return grouped
 
-    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
 
-    for layer_name, layer_state_dict in grouped_state_dict.items():
-        if layer_name.startswith("diffusion_model"):
-            transformer_grouped_sd[layer_name] = layer_state_dict
-        else:
-            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
+    grouped = _group_state_by_submodel(state_dict)
 
     layers: dict[str, BaseLayerPatch] = {}
-    for layer_key, layer_state_dict in transformer_grouped_sd.items():
+    for layer_key, layer_state_dict in grouped.transformer.items():
         layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     return ModelPatchRaw(layers=layers)
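The refactor above swaps the flat defaultdict grouping for a small dataclass keyed by submodel, so CLIP or T5 buckets can be added later without touching callers. A standalone sketch of the same split-on-first-dot technique, assuming only the "diffusion_model" bucket matters (illustrative names, no InvokeAI imports):

    from dataclasses import dataclass, field


    @dataclass
    class Grouped:
        # One bucket per submodel prefix, mirroring GroupedStateDict in the diff.
        transformer: dict[str, object] = field(default_factory=dict)


    def group_by_submodel(state_dict: dict[str, object]) -> Grouped:
        grouped = Grouped()
        for key, value in state_dict.items():
            submodel, param = key.split(".", 1)  # split on the first dot only
            if submodel == "diffusion_model":
                grouped.transformer[param] = value
        return grouped


    example = {"diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": "A"}
    assert list(group_by_submodel(example).transformer) == ["double_blocks.0.img_attn.qkv.lora_A.weight"]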
@@ -1,6 +1,6 @@
 from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
 from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
-    is_state_dict_likely_in_aitoolkit_format,
+    is_state_dict_likely_in_flux_aitoolkit_format,
 )
 from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
 from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
@@ -15,7 +15,7 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
 
 
 def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
-    if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
+    if is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata):
         return FluxLoRAFormat.AIToolkit
     if is_state_dict_likely_in_flux_kohya_format(state_dict):
         return FluxLoRAFormat.Kohya
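For context, the metadata argument that drives the AIToolkit branch is the free-form header carried by .safetensors files; per the detector above, AI-Toolkit records a JSON "software" entry there. A hedged sketch of wiring that header into the renamed detector (the file path is a placeholder; safe_open, metadata(), and load_file are standard safetensors APIs):

    from safetensors import safe_open
    from safetensors.torch import load_file

    from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
        is_state_dict_likely_in_flux_aitoolkit_format,
    )

    lora_path = "some_flux_aitoolkit_lora.safetensors"  # placeholder path

    with safe_open(lora_path, framework="pt") as f:
        # dict[str, str] or None; AI-Toolkit exports include a "software" JSON string here.
        metadata = f.metadata()

    state_dict = load_file(lora_path)
    print(is_state_dict_likely_in_flux_aitoolkit_format(state_dict, metadata))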
@@ -0,0 +1,458 @@
state_dict_keys = {
    "diffusion_model.double_blocks.0.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.0.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.0.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.0.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.0.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.0.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.0.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.0.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.0.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.0.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.1.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.1.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.1.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.1.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.1.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.1.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.1.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.1.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.1.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.1.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.1.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.10.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.10.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.10.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.10.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.10.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.10.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.10.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.10.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.10.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.10.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.10.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.11.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.11.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.11.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.11.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.11.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.11.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.11.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.11.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.11.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.11.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.11.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.12.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.12.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.12.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.12.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.12.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.12.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.12.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.12.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.12.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.12.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.12.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.13.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.13.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.13.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.13.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.13.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.13.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.13.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.13.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.13.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.13.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.13.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.14.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.14.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.14.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.14.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.14.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.14.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.14.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.14.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.14.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.14.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.14.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.15.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.15.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.15.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.15.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.15.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.15.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.15.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.15.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.15.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.15.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.15.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.16.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.16.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.16.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.16.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.16.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.16.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.16.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.16.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.16.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.16.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.16.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.17.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.17.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.17.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.17.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.17.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.17.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.17.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.17.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.17.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.17.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.17.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.18.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.18.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.18.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.18.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.18.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.18.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.18.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.18.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.18.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.18.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.18.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.2.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.2.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.2.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.2.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.2.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.2.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.2.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.2.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.2.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.2.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.2.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.3.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.3.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.3.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.3.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.3.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.3.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.3.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.3.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.3.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.3.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.3.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.4.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.4.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.4.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.4.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.4.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.4.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.4.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.4.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.4.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.4.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.4.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.5.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.5.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.5.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.5.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.5.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.5.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.5.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.5.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.5.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.5.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.5.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.6.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.6.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.6.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.6.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.6.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.6.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.6.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.6.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.6.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.6.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.6.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.7.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.7.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.7.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.7.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.7.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.7.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.7.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.7.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.7.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.7.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.7.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.8.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.8.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.8.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.8.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.8.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.8.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.8.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.8.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.8.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.8.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.8.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.9.img_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.img_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.9.img_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.img_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.9.img_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.img_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.9.img_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.9.img_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.9.txt_attn.proj.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.txt_attn.proj.lora_B.weight": [3072, 16],
    "diffusion_model.double_blocks.9.txt_attn.qkv.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.txt_attn.qkv.lora_B.weight": [9216, 16],
    "diffusion_model.double_blocks.9.txt_mlp.0.lora_A.weight": [16, 3072],
    "diffusion_model.double_blocks.9.txt_mlp.0.lora_B.weight": [12288, 16],
    "diffusion_model.double_blocks.9.txt_mlp.2.lora_A.weight": [16, 12288],
    "diffusion_model.double_blocks.9.txt_mlp.2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.0.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.0.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.1.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.1.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.1.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.1.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.10.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.10.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.10.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.10.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.11.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.11.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.11.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.11.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.12.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.12.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.12.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.12.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.13.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.13.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.13.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.13.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.14.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.14.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.14.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.14.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.15.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.15.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.15.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.15.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.16.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.16.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.16.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.16.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.17.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.17.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.17.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.17.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.18.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.18.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.18.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.18.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.19.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.19.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.19.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.19.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.2.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.2.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.2.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.2.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.20.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.20.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.20.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.20.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.21.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.21.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.21.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.21.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.22.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.22.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.22.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.22.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.23.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.23.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.23.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.23.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.24.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.24.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.24.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.24.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.25.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.25.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.25.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.25.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.26.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.26.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.26.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.26.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.27.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.27.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.27.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.27.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.28.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.28.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.28.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.28.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.29.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.29.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.29.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.29.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.3.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.3.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.3.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.3.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.30.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.30.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.30.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.30.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.31.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.31.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.31.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.31.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.32.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.32.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.32.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.32.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.33.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.33.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.33.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.33.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.34.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.34.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.34.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.34.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.35.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.35.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.35.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.35.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.36.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.36.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.36.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.36.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.37.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.37.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.37.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.37.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.4.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.4.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.4.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.4.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.5.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.5.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.5.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.5.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.6.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.6.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.6.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.6.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.7.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.7.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.7.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.7.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.8.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.8.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.8.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.8.linear2.lora_B.weight": [3072, 16],
    "diffusion_model.single_blocks.9.linear1.lora_A.weight": [16, 3072],
    "diffusion_model.single_blocks.9.linear1.lora_B.weight": [21504, 16],
    "diffusion_model.single_blocks.9.linear2.lora_A.weight": [16, 15360],
    "diffusion_model.single_blocks.9.linear2.lora_B.weight": [3072, 16],
}
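The fixture above records only key names and weight shapes; the tests materialize it through keys_to_mock_state_dict (imported in the test file below), whose implementation is not part of this diff. A hedged sketch of what such a helper can look like, assuming zero-filled tensors are enough for shape and key checks:

    import torch


    def shapes_to_mock_state_dict(state_dict_keys: dict[str, list[int]]) -> dict[str, torch.Tensor]:
        # Illustrative only: the repo's keys_to_mock_state_dict may choose a different
        # dtype, fill value, or device.
        return {key: torch.zeros(shape) for key, shape in state_dict_keys.items()}


    mock_sd = shapes_to_mock_state_dict(
        {
            "diffusion_model.single_blocks.0.linear1.lora_A.weight": [16, 3072],
            "diffusion_model.single_blocks.0.linear1.lora_B.weight": [21504, 16],
        }
    )
    assert mock_sd["diffusion_model.single_blocks.0.linear1.lora_B.weight"].shape == (21504, 16)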
@@ -0,0 +1,60 @@
import accelerate
import pytest

from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
    _group_state_by_submodel,
    is_state_dict_likely_in_flux_aitoolkit_format,
    lora_model_from_flux_aitoolkit_state_dict,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
    state_dict_keys as flux_onetrainer_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import (
    state_dict_keys as flux_diffusers_state_dict_keys,
)
from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_aitoolkit_format import state_dict_keys as flux_aitoolkit_state_dict_keys
from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict


def test_is_state_dict_likely_in_flux_aitoolkit_format():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
    assert is_state_dict_likely_in_flux_aitoolkit_format(state_dict)


@pytest.mark.parametrize("sd_keys", [flux_diffusers_state_dict_keys, flux_onetrainer_state_dict_keys])
def test_is_state_dict_likely_in_flux_kohya_format_false(sd_keys: dict[str, list[int]]):
    state_dict = keys_to_mock_state_dict(sd_keys)
    assert not is_state_dict_likely_in_flux_aitoolkit_format(state_dict)


def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)
    converted_state_dict = _group_state_by_submodel(state_dict).transformer

    # Extract the prefixes from the converted state dict (without the lora suffixes)
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        k = k.replace(".lora_A.weight", "")
        k = k.replace(".lora_B.weight", "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device.
    with accelerate.init_empty_weights():
        model = Flux(params["flux-schnell"])
    model_keys = set(model.state_dict().keys())

    for converted_key_prefix in converted_key_prefixes:
        assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), f"'{converted_key_prefix}' did not match any model keys."


def test_lora_model_from_flux_aitoolkit_state_dict():
    state_dict = keys_to_mock_state_dict(flux_aitoolkit_state_dict_keys)

    lora_model = lora_model_from_flux_aitoolkit_state_dict(state_dict)

    # Assert that the lora_model has the expected layers.
    # lora_model_keys = set(lora_model.layers.keys())
    # lora_model_keys = {k.replace(".", "_") for k in lora_model_keys}
    # assert lora_model_keys == expected_layer_keys
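One detail of test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format worth noting: accelerate.init_empty_weights lets the test enumerate the real FLUX parameter names without allocating the full weights. A minimal sketch of the same meta-device pattern, using a small stand-in module instead of Flux:

    import accelerate
    import torch


    with accelerate.init_empty_weights():
        layer = torch.nn.Linear(3072, 9216)

    # Parameters live on the "meta" device: key names and shapes are available for
    # comparison, but no memory is spent on the actual values.
    print(layer.weight.device)       # meta
    print(list(layer.state_dict()))  # ['weight', 'bias']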