feat(LoRA): support AI Toolkit LoRA for FLUX [WIP]

This commit is contained in:
Kevin Turner
2025-05-30 16:08:00 -07:00
committed by psychedelicious
parent 3df7cfd605
commit 5c5108c28a
4 changed files with 49 additions and 2 deletions

View File

@@ -296,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
sd = mod.load_state_dict(mod.path)
value = flux_format_from_state_dict(sd)
value = flux_format_from_state_dict(sd, mod.metadata())
mod.cache[key] = value
return value

View File

@@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum):
Kohya = "flux.kohya"
OneTrainer = "flux.onetrainer"
Control = "flux.control"
AIToolkit = "flux.aitoolkit"
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]

View File

@@ -0,0 +1,41 @@
import json
from collections import defaultdict
from typing import Any
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
lora_layers_from_flux_diffusers_grouped_state_dict,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
if metadata:
software = json.loads(metadata.get("software", "{}"))
return software.get("name") == "ai-toolkit"
# metadata got lost somewhere
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
    """Build a ModelPatchRaw from an AI Toolkit FLUX LoRA state dict.

    The flat keys ("<layer>.<param>") are bucketed per layer and handed to the
    shared diffusers-format grouped-state-dict converter.

    Raises:
        ValueError: if a key belongs to anything other than the FLUX transformer
            (AI Toolkit LoRAs only patch "diffusion_model.*" layers).
    """
    # Bucket flat keys into one sub-dict per layer.
    by_layer: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
    for full_key, tensor in state_dict.items():
        layer, param = full_key.split(".", 1)
        by_layer[layer][param] = tensor

    transformer_sd: dict[str, dict[str, torch.Tensor]] = {}
    for layer, params in by_layer.items():
        if not layer.startswith("diffusion_model"):
            raise ValueError(f"Layer '{layer}' does not match the expected pattern for FLUX LoRA weights.")
        transformer_sd[layer] = params

    # AI Toolkit stores weights in the diffusers grouping, so reuse that converter.
    layers: dict[str, BaseLayerPatch] = lora_layers_from_flux_diffusers_grouped_state_dict(
        transformer_sd, alpha=None
    )
    return ModelPatchRaw(layers=layers)

View File

@@ -1,4 +1,7 @@
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_aitoolkit_format,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
@@ -11,7 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
)
def flux_format_from_state_dict(state_dict):
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
return FluxLoRAFormat.AIToolkit
if is_state_dict_likely_in_flux_kohya_format(state_dict):
return FluxLoRAFormat.Kohya
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):