mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
Convert from OMI to default LoRA state dict
This commit is contained in:
@ -11,6 +11,7 @@ from safetensors.torch import load_file
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig
|
||||
from invokeai.backend.model_manager.load.load_default import ModelLoader
|
||||
from invokeai.backend.model_manager.omi import convert_from_omi
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
@ -73,6 +74,10 @@ class LoRALoader(ModelLoader):
|
||||
else:
|
||||
state_dict = torch.load(model_path, map_location="cpu")
|
||||
|
||||
if config.format == ModelFormat.OMI:
|
||||
state_dict = convert_from_omi(state_dict)
|
||||
|
||||
|
||||
# Apply state_dict key conversions, if necessary.
|
||||
if self._model_base == BaseModelType.StableDiffusionXL:
|
||||
state_dict = convert_sdxl_keys_to_diffusers_format(state_dict)
|
||||
@ -85,7 +90,7 @@ class LoRALoader(ModelLoader):
|
||||
# is a popular choice. For example, in the diffusers training scripts:
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
elif config.format in [ModelFormat.LyCORIS, ModelFormat.OMI]:
|
||||
if is_state_dict_likely_in_flux_kohya_format(state_dict=state_dict):
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
elif is_state_dict_likely_in_flux_onetrainer_format(state_dict=state_dict):
|
||||
|
43
invokeai/backend/model_manager/omi.py
Normal file
43
invokeai/backend/model_manager/omi.py
Normal file
@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
def convert_from_omi(weights_sd):
|
||||
# convert from OMI to default LoRA
|
||||
# OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...}
|
||||
# default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
|
||||
|
||||
new_weights_sd = {}
|
||||
prefix = "lora_unet_"
|
||||
lora_dims = {}
|
||||
weight_dtype = None
|
||||
for key, weight in weights_sd.items():
|
||||
omi_prefix, key_body = key.split(".", 1)
|
||||
if omi_prefix != "diffusion":
|
||||
logger.warning(f"unexpected key: {key} in OMI format") # T5, CLIP, etc.
|
||||
continue
|
||||
|
||||
# only supports lora_down, lora_up and alpha
|
||||
new_key = (
|
||||
f"{prefix}{key_body}".replace(".", "_")
|
||||
.replace("_lora_down_", ".lora_down.")
|
||||
.replace("_lora_up_", ".lora_up.")
|
||||
.replace("_alpha", ".alpha")
|
||||
)
|
||||
new_weights_sd[new_key] = weight
|
||||
|
||||
lora_name = new_key.split(".")[0] # before first dot
|
||||
if lora_name not in lora_dims and "lora_down" in new_key:
|
||||
lora_dims[lora_name] = weight.shape[0]
|
||||
if weight_dtype is None:
|
||||
weight_dtype = weight.dtype # use first weight dtype for lora_down
|
||||
|
||||
# add alpha with rank
|
||||
for lora_name, dim in lora_dims.items():
|
||||
alpha_key = f"{lora_name}.alpha"
|
||||
if alpha_key not in new_weights_sd:
|
||||
new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype)
|
||||
|
||||
return new_weights_sd
|
Reference in New Issue
Block a user