mirror of https://github.com/invoke-ai/InvokeAI, synced 2025-07-26 05:17:55 +00:00

Use OMI conversion utils
@@ -357,15 +357,12 @@ class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
         if "stable-diffusion-v1" in base_str:
             base = BaseModelType.StableDiffusion1
         elif "stable-diffusion-v2" in base_str:
             base = BaseModelType.StableDiffusion2
         elif "stable-diffusion-v3" in base_str:
             base = BaseModelType.StableDiffusion3
         elif base_str == "stable-diffusion-xl-v1-base":
             base = BaseModelType.StableDiffusionXL
         elif "flux" in base_str:
             base = BaseModelType.Flux
         else:
             raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")
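For reference, the detection logic can be exercised in isolation. A minimal standalone sketch, with a stand-in enum rather than InvokeAI's BaseModelType and hypothetical metadata strings:

from enum import Enum


class Base(Enum):  # stand-in for invokeai's BaseModelType
    SD1 = "sd-1"
    SD2 = "sd-2"
    SD3 = "sd-3"
    SDXL = "sdxl"
    FLUX = "flux"


def detect_base(base_str: str) -> Base:
    # Mirrors the substring checks in the hunk above.
    if "stable-diffusion-v1" in base_str:
        return Base.SD1
    elif "stable-diffusion-v2" in base_str:
        return Base.SD2
    elif "stable-diffusion-v3" in base_str:
        return Base.SD3
    elif base_str == "stable-diffusion-xl-v1-base":
        return Base.SDXL
    elif "flux" in base_str:
        return Base.FLUX
    raise ValueError(f"Unrecognised base architecture for OMI LoRA: {base_str}")


assert detect_base("stable-diffusion-v1-base") is Base.SD1  # hypothetical metadata value
assert detect_base("flux-dev-1") is Base.FLUX  # hypothetical metadata value

Note that only SDXL is matched on the exact string; the other branches are substring checks, so any base_str containing "flux", for example, is treated as a Flux LoRA.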
@@ -13,7 +13,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig
 from invokeai.backend.model_manager.load.load_default import ModelLoader
 from invokeai.backend.model_manager.load.model_cache.model_cache import ModelCache
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
-from invokeai.backend.model_manager.omi import convert_from_omi
+from invokeai.backend.model_manager.omi import convert_to_omi
 from invokeai.backend.model_manager.taxonomy import (
     AnyModel,
     BaseModelType,
@@ -41,7 +41,6 @@ from invokeai.backend.patches.lora_conversions.sdxl_lora_conversion_utils import

 @ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.LoRA, format=ModelFormat.OMI)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.LoRA, format=ModelFormat.OMI)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.LoRA, format=ModelFormat.OMI)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion3, type=ModelType.LoRA, format=ModelFormat.OMI)
 @ModelLoaderRegistry.register(base=BaseModelType.StableDiffusionXL, type=ModelType.LoRA, format=ModelFormat.OMI)
@@ -80,7 +79,7 @@ class LoRALoader(ModelLoader):
             state_dict = torch.load(model_path, map_location="cpu")

         if config.format == ModelFormat.OMI:
-            state_dict = convert_from_omi(state_dict)
+            state_dict = convert_to_omi(state_dict, config.base)  # type: ignore

         # Apply state_dict key conversions, if necessary.
         if self._model_base == BaseModelType.StableDiffusionXL:
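The conversion hook sits between raw checkpoint loading and the per-architecture key conversions. A minimal sketch of that flow, assuming a suffix check for safetensors files; the surrounding loader details are simplified, not InvokeAI's exact code:

from pathlib import Path

import torch
from safetensors.torch import load_file

from invokeai.backend.model_manager.omi import convert_to_omi


def load_lora_state_dict(model_path: Path, is_omi: bool, base):
    # Load raw weights: safetensors when possible, torch.load otherwise.
    if model_path.suffix == ".safetensors":
        state_dict = load_file(model_path, device="cpu")
    else:
        state_dict = torch.load(model_path, map_location="cpu")
    # OMI checkpoints are normalized before the per-base key conversions run.
    if is_omi:
        state_dict = convert_to_omi(state_dict, base)
    return state_dict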
@@ -1,44 +1,17 @@
-import torch
-
-from invokeai.backend.util.logging import InvokeAILogger
-
-logger = InvokeAILogger.get_logger()
+from invokeai.backend.model_manager.model_on_disk import StateDict
+from invokeai.backend.model_manager.taxonomy import BaseModelType
+from omi_model_standards.convert.lora.convert_sdxl_lora import convert_sdxl_lora_key_sets
+from omi_model_standards.convert.lora.convert_flux_lora import convert_flux_lora_key_sets
+from omi_model_standards.convert.lora.convert_sd_lora import convert_sd_lora_key_sets
+from omi_model_standards.convert.lora.convert_sd3_lora import convert_sd3_lora_key_sets
+import omi_model_standards.convert.lora.convert_lora_util as lora_util


-def convert_from_omi(weights_sd):
-    # convert from OMI to default LoRA
-    # OMI format: {"prefix.module.name.lora_down.weight": weight, "prefix.module.name.lora_up.weight": weight, ...}
-    # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
-
-    new_weights_sd = {}
-    prefix = "lora_unet_"
-    lora_dims = {}
-    weight_dtype = None
-    for key, weight in weights_sd.items():
-        omi_prefix, key_body = key.split(".", 1)
-        if omi_prefix != "diffusion":
-            logger.warning(f"unexpected key: {key} in OMI format")  # T5, CLIP, etc.
-            continue
-
-        # only supports lora_down, lora_up and alpha
-        new_key = (
-            f"{prefix}{key_body}".replace(".", "_")
-            .replace("_lora_down_", ".lora_down.")
-            .replace("_lora_up_", ".lora_up.")
-            .replace("_alpha", ".alpha")
-        )
-        new_weights_sd[new_key] = weight
-
-        lora_name = new_key.split(".")[0]  # before first dot
-        if lora_name not in lora_dims and "lora_down" in new_key:
-            lora_dims[lora_name] = weight.shape[0]
-            if weight_dtype is None:
-                weight_dtype = weight.dtype  # use first weight dtype for lora_down
-
-    # add alpha with rank
-    for lora_name, dim in lora_dims.items():
-        alpha_key = f"{lora_name}.alpha"
-        if alpha_key not in new_weights_sd:
-            new_weights_sd[alpha_key] = torch.tensor(dim, dtype=weight_dtype)
-
-    return new_weights_sd
+def convert_to_omi(weights_sd: StateDict, base: BaseModelType):
+    keyset = {
+        BaseModelType.Flux: convert_flux_lora_key_sets(),
+        BaseModelType.StableDiffusionXL: convert_sdxl_lora_key_sets(),
+        BaseModelType.StableDiffusion1: convert_sd_lora_key_sets(),
+        BaseModelType.StableDiffusion3: convert_sd3_lora_key_sets(),
+    }[base]
+    return lora_util.convert_to_omi(weights_sd, keyset)
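The removed hand-rolled converter documents the key scheme it undid: OMI keys are dot-separated under a "diffusion." prefix, while the legacy LoRA scheme flattens the module path with underscores and keeps dots only around lora_down/lora_up/alpha. A small worked example of that renaming, using toy keys and shapes that are not taken from any real checkpoint:

import torch

# Toy OMI-style down/up pair under the "diffusion" prefix.
omi_sd = {
    "diffusion.down_blocks.0.lora_down.weight": torch.zeros(4, 320),
    "diffusion.down_blocks.0.lora_up.weight": torch.zeros(320, 4),
}


def to_legacy_key(key: str) -> str:
    body = key.split(".", 1)[1]  # drop the "diffusion" prefix
    return (
        ("lora_unet_" + body).replace(".", "_")
        .replace("_lora_down_", ".lora_down.")
        .replace("_lora_up_", ".lora_up.")
        .replace("_alpha", ".alpha")
    )


legacy = {to_legacy_key(k): v for k, v in omi_sd.items()}
assert "lora_unet_down_blocks_0.lora_down.weight" in legacy

# The removed code also synthesized a missing alpha from the lora_down rank:
rank = legacy["lora_unet_down_blocks_0.lora_down.weight"].shape[0]
legacy.setdefault("lora_unet_down_blocks_0.alpha", torch.tensor(rank, dtype=torch.float32))

The new convert_to_omi delegates this mapping to the omi_model_standards key sets instead. Its dispatch dict covers only Flux, SDXL, SD1, and SD3, so any other base raises KeyError, including BaseModelType.StableDiffusion2, which the loader registers for the OMI format.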
@@ -74,6 +74,7 @@ dependencies = [
     "python-multipart",
     "requests",
     "semver~=3.0.1",
+    "omi-model-standards @ git+https://github.com/Open-Model-Initiative/OMI-Model-Standards.git@4ad235ceba6b42a97942834b7664379e4ec2d93c"
 ]

 [project.optional-dependencies]
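The new dependency is pinned to an exact commit of the OMI-Model-Standards repository. A quick way to confirm the pinned distribution is present in an environment, assuming the distribution name matches the spec above:

from importlib.metadata import PackageNotFoundError, version

try:
    print(version("omi-model-standards"))
except PackageNotFoundError:
    print("omi-model-standards is not installed")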