mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
feat(LoRA): support AI Toolkit LoRA for FLUX [WIP]
This commit is contained in:
committed by
psychedelicious
parent
3df7cfd605
commit
5c5108c28a
@ -296,7 +296,7 @@ class LoRAConfigBase(ABC, BaseModel):
|
||||
from invokeai.backend.patches.lora_conversions.formats import flux_format_from_state_dict
|
||||
|
||||
sd = mod.load_state_dict(mod.path)
|
||||
value = flux_format_from_state_dict(sd)
|
||||
value = flux_format_from_state_dict(sd, mod.metadata())
|
||||
mod.cache[key] = value
|
||||
return value
|
||||
|
||||
|
@ -137,6 +137,7 @@ class FluxLoRAFormat(str, Enum):
|
||||
Kohya = "flux.kohya"
|
||||
OneTrainer = "flux.onetrainer"
|
||||
Control = "flux.control"
|
||||
AIToolkit = "flux.aitoolkit"
|
||||
|
||||
|
||||
AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
|
||||
|
@ -0,0 +1,41 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
lora_layers_from_flux_diffusers_grouped_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
|
||||
if metadata:
|
||||
software = json.loads(metadata.get("software", "{}"))
|
||||
return software.get("name") == "ai-toolkit"
|
||||
# metadata got lost somewhere
|
||||
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
|
||||
|
||||
|
||||
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
layer_name, param_name = key.split(".", 1)
|
||||
grouped_state_dict[layer_name][param_name] = value
|
||||
|
||||
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
|
||||
|
||||
for layer_name, layer_state_dict in grouped_state_dict.items():
|
||||
if layer_name.startswith("diffusion_model"):
|
||||
transformer_grouped_sd[layer_name] = layer_state_dict
|
||||
else:
|
||||
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
|
||||
|
||||
layers: dict[str, BaseLayerPatch] = lora_layers_from_flux_diffusers_grouped_state_dict(
|
||||
transformer_grouped_sd, alpha=None
|
||||
)
|
||||
|
||||
return ModelPatchRaw(layers=layers)
|
@ -1,4 +1,7 @@
|
||||
from invokeai.backend.model_manager.taxonomy import FluxLoRAFormat
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_aitoolkit_format,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
|
||||
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_flux_diffusers_format,
|
||||
@ -11,7 +14,9 @@ from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_u
|
||||
)
|
||||
|
||||
|
||||
def flux_format_from_state_dict(state_dict):
|
||||
def flux_format_from_state_dict(state_dict: dict, metadata: dict | None = None) -> FluxLoRAFormat | None:
|
||||
if is_state_dict_likely_in_aitoolkit_format(state_dict, metadata):
|
||||
return FluxLoRAFormat.AIToolkit
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict):
|
||||
return FluxLoRAFormat.Kohya
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict):
|
||||
|
Reference in New Issue
Block a user