Use OMI conversion utils

This commit is contained in:
Billy
2025-06-19 09:40:49 +10:00
parent 2876c72fa9
commit 45d09f8f51
4 changed files with 18 additions and 48 deletions

View File

@ -357,15 +357,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
if "stable-diffusion-v1" in base_str:
base = BaseModelType.StableDiffusion1
elif "stable-diffusion-v2" in base_str:
base = BaseModelType.StableDiffusion2
elif "stable-diffusion-v3" in base_str:
base = BaseModelType.StableDiffusion3
elif base_str == "stable-diffusion-xl-v1-base":
base = BaseModelType.StableDiffusionXL
elif "flux" in base_str:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")

View File

@ -13,7 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig
from invokeai.backend.model_manager.load.load_default import ModelLoader
from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.omi import convert_from_omi
from invokeai.backend.model_manager.omi import convert_to_omi
from invokeai.backend.model_manager.taxonomy import (
AnyModel,
BaseModelType,
@ -41,7 +41,6 @@ from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.LoRA, format=ModelFormat.OMI)
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
@ -80,7 +79,7 @@ class LoRALoader(ModelLoader):
state_dict = torch.load(model_path, map_location="cpu")
if config.format == ModelFormat.OMI:
state_dict = convert_from_omi(state_dict)
state_dict = convert_to_omi(state_dict. config.base) # type: ignore
# Apply state_dict key conversions, if necessary.
if self._model_base == BaseModelType.StableDiffusionXL:

View File

@ -1,44 +1,17 @@
import torch
from invokeai.backend.util.logging import InvokeAILogger
logger = InvokeAILogger.get_logger()
from invokeai.backend.model_manager.model_on_disk import StateDict
from invokeai.backend.model_manager.taxonomy import BaseModelType
from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets
from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets
from omi_model_standards.convert.lora.convert_sd_lora import convert_sd_lora_key_sets
from omi_model_standards.convert.lora.convert_sd3_lora import convert_sd3_lora_key_sets
import omi_model_standards.convert.lora.convert_lora_util as lora_util
def convert_from_omi(weights_sd):
# convert from OMI to default LoRA
# OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...}
# default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
new_weights_sd = {}
prefix = "lora_unet_"
lora_dims = {}
weight_dtype = None
for key, weight in weights_sd.items():
omi_prefix, key_body = key.split(".", 1)
if omi_prefix != "diffusion":
logger.warning(f"unexpected key: {key} in OMI format") # T5, CLIP, etc.
continue
# only supports lora_down, lora_up and alpha
new_key = (
f"{prefix}{key_body}".replace(".", "_")
.replace("_lora_down_", ".lora_down.")
.replace("_lora_up_", ".lora_up.")
.replace("_alpha", ".alpha")
)
new_weights_sd[new_key] = weight
lora_name = new_key.split(".")[0] # before first dot
if lora_name not in lora_dims and "lora_down" in new_key:
lora_dims[lora_name] = weight.shape[0]
if weight_dtype is None:
weight_dtype = weight.dtype # use first weight dtype for lora_down
# add alpha with rank
for lora_name, dim in lora_dims.items():
alpha_key = f"{lora_name}.alpha"
if alpha_key not in new_weights_sd:
new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype)
return new_weights_sd
def convert_to_omi(weights_sd: StateDict, base: BaseModelType):
keyset = {
BaseModelType.Flux: convert_flux_lora_key_sets(),
BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(),
BaseModelType.StableDiffusion1: convert_sd_lora_key_sets(),
BaseModelType.StableDiffusion3: convert_sd3_lora_key_sets(),
}[base]
return lora_util.convert_to_omi(weights_sd, keyset)

View File

@ -74,6 +74,7 @@ dependencies = [
"python-multipart",
"requests",
"semver~=3.0.1",
"omi-model-standards @ git+https://github.com/Open-Model-Initiative/OMI-Model-Standards.git@4ad235ceba6b42a97942834b7664379e4ec2d93c"
]
[project.optional-dependencies]