fix(LoRA): add ai-toolkit to lora loader

This commit is contained in:
Kevin Turner
2025-05-30 16:31:28 -07:00
committed by psychedelicious
parent 5c5108c28a
commit ab8c739cd8
2 changed files with 8 additions and 2 deletions

View File

@ -20,6 +20,10 @@ from invokeai.backend.model_manager.taxonomy import (
ModelType,
SubModelType,
)
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
is_state_dict_likely_in_aitoolkit_format,
lora_model_from_flux_aitoolkit_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
is_state_dict_likely_flux_control,
lora_model_from_flux_control_state_dict,
@ -92,6 +96,8 @@ class LoRALoader(ModelLoader):
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
elif is_state_dict_likely_flux_control(state_dict=state_dict):
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
else:
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
else:

View File

@ -11,7 +11,7 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
if metadata:
software = json.loads(metadata.get("software", "{}"))
return software.get("name") == "ai-toolkit"
@ -19,7 +19,7 @@ def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadat
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
for key, value in state_dict.items():