Add OMI LoRA config

This commit is contained in:
Billy
2025-06-17 13:34:03 +10:00
parent 8f152f162b
commit 85c4304efd

View File

@ -31,6 +31,7 @@ from pathlib import Path
from typing import ClassVar, Literal, Optional, TypeAlias, Union
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from pydantic.main import Model
from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string
@ -334,6 +335,44 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
@classmethod
def matches(cls, mod: ModelOnDisk) -> bool:
if mod.path.is_dir():
return False
metadata = mod.metadata()
return (
metadata.get("modelspec.sai_model_spec") and
metadata.get("ot_branch") == "omi_format" and
metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
)
@classmethod
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
metadata = mod.metadata()
base_str, _ = metadata["modelspec.architecture"].split("/")
base_str = base_str.lower()
if "stable-diffusion-v1" in base_str:
base = BaseModelType.StableDiffusion1
elif "stable-diffusion-v2" in base_str:
base = BaseModelType.StableDiffusion2
elif "stable-diffusion-v3" in base_str:
base = BaseModelType.StableDiffusion3
elif base_str == "stable-diffusion-xl-v1-base":
base = BaseModelType.StableDiffusionXL
elif "flux" in base_str:
base = BaseModelType.Flux
else:
raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")
return { "base": base }
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
@ -668,6 +707,7 @@ AnyModelConfig = Annotated[
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],