mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
fix(LoRA): add ai-toolkit to lora loader
This commit is contained in:
committed by
psychedelicious
parent
5c5108c28a
commit
ab8c739cd8
@ -20,6 +20,10 @@ from invokeai.backend.model_manager.taxonomy import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_aitoolkit_lora_conversion_utils import (
|
||||
is_state_dict_likely_in_aitoolkit_format,
|
||||
lora_model_from_flux_aitoolkit_state_dict,
|
||||
)
|
||||
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import (
|
||||
is_state_dict_likely_flux_control,
|
||||
lora_model_from_flux_control_state_dict,
|
||||
@ -92,6 +96,8 @@ class LoRALoader(ModelLoader):
|
||||
model = lora_model_from_flux_onetrainer_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_flux_control(state_dict=state_dict):
|
||||
model = lora_model_from_flux_control_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_aitoolkit_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_aitoolkit_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
raise ValueError(f"LoRA model is in unsupported FLUX format: {config.format}")
|
||||
else:
|
||||
|
@ -11,7 +11,7 @@ from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_ut
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
|
||||
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any]) -> bool:
|
||||
def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadata: dict[str, Any] = None) -> bool:
|
||||
if metadata:
|
||||
software = json.loads(metadata.get("software", "{}"))
|
||||
return software.get("name") == "ai-toolkit"
|
||||
@ -19,7 +19,7 @@ def is_state_dict_likely_in_aitoolkit_format(state_dict: dict[str, Any], metadat
|
||||
return any("diffusion_model" == k.split(".", 1)[0] for k in state_dict.keys())
|
||||
|
||||
|
||||
def lora_model_from_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
def lora_model_from_flux_aitoolkit_state_dict(state_dict: dict[str, torch.Tensor]) -> ModelPatchRaw:
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
|
Reference in New Issue
Block a user