Convert from OMI to default LoRA state dict

This commit is contained in:
Billy
2025-06-17 13:56:22 +10:00
parent 85c4304efd
commit 84ab4a1c30
2 changed files with 49 additions and 1 deletions

View File

@ -11,6 +11,7 @@ from safetensors.torch import load_file
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.omi import convert_from_omi
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.taxonomy import (
@ -73,6 +74,10 @@ class LoRALoader(ModelLoader):
else:
state_dict = torch.load(model_path, map_location="cpu")
if config.format == ModelFormat.OMI:
state_dict = convert_from_omi(state_dict)
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
@ -85,7 +90,7 @@ class LoRALoader(ModelLoader):
# is a popular choice. For example, in the diffusers training scripts:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
elif config.format in [ModelFormat.LyCORIS, ModelFormat.OMI]:
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):

View File

@ -0,0 +1,43 @@
import torch
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
def convert_from_omi(weights_sd):
# convert from OMI to default LoRA
# OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...}
# default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
new_weights_sd = {}
prefix = "lora_unet_"
lora_dims = {}
weight_dtype = None
for key, weight in weights_sd.items():
omi_prefix, key_body = key.split(".", 1)
if omi_prefix != "diffusion":
logger.warning(f"unexpected key: {key} in OMI format") # T5, CLIP, etc.
continue
# only supports lora_down, lora_up and alpha
new_key = (
f"{prefix}{key_body}".replace(".", "_")
.replace("_lora_down_", ".lora_down.")
.replace("_lora_up_", ".lora_up.")
.replace("_alpha", ".alpha")
)
new_weights_sd[new_key] = weight
lora_name = new_key.split(".")[0] # before first dot
if lora_name not in lora_dims and "lora_down" in new_key:
lora_dims[lora_name] = weight.shape[0]
if weight_dtype is None:
weight_dtype = weight.dtype # use first weight dtype for lora_down
# add alpha with rank
for lora_name, dim in lora_dims.items():
alpha_key = f"{lora_name}.alpha"
if alpha_key not in new_weights_sd:
new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype)
return new_weights_sd