mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
Add OMI LoRA config
This commit is contained in:
@ -31,6 +31,7 @@ from pathlib import Path
|
||||
from typing import ClassVar, Literal, Optional, TypeAlias, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||
from pydantic.main import Model
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
@ -334,6 +335,44 @@ class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase, LegacyProbeMixin,
|
||||
format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = ModelFormat.BnbQuantizedLlmInt8b
|
||||
|
||||
|
||||
class LoRAOmiConfig(LoRAConfigBase, ModelConfigBase):
|
||||
format: Literal[ModelFormat.OMI] = ModelFormat.OMI
|
||||
|
||||
@classmethod
|
||||
def matches(cls, mod: ModelOnDisk) -> bool:
|
||||
if mod.path.is_dir():
|
||||
return False
|
||||
|
||||
metadata = mod.metadata()
|
||||
return (
|
||||
metadata.get("modelspec.sai_model_spec") and
|
||||
metadata.get("ot_branch") == "omi_format" and
|
||||
metadata["modelspec.architecture"].split("/")[1].lower() == "lora"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def parse(cls, mod: ModelOnDisk) -> dict[str, Any]:
|
||||
metadata = mod.metadata()
|
||||
base_str, _ = metadata["modelspec.architecture"].split("/")
|
||||
base_str = base_str.lower()
|
||||
|
||||
if "stable-diffusion-v1" in base_str:
|
||||
base = BaseModelType.StableDiffusion1
|
||||
elif "stable-diffusion-v2" in base_str:
|
||||
base = BaseModelType.StableDiffusion2
|
||||
elif "stable-diffusion-v3" in base_str:
|
||||
base = BaseModelType.StableDiffusion3
|
||||
elif base_str == "stable-diffusion-xl-v1-base":
|
||||
base = BaseModelType.StableDiffusionXL
|
||||
elif "flux" in base_str:
|
||||
base = BaseModelType.Flux
|
||||
|
||||
else:
|
||||
raise InvalidModelConfigException(f"Unrecognised base architecture for OMI LoRA: {base_str}")
|
||||
|
||||
return { "base": base }
|
||||
|
||||
|
||||
class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
"""Model config for LoRA/Lycoris models."""
|
||||
|
||||
@ -668,6 +707,7 @@ AnyModelConfig = Annotated[
|
||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||
Annotated[LoRAOmiConfig, LoRAOmiConfig.get_tag()],
|
||||
Annotated[ControlLoRALyCORISConfig, ControlLoRALyCORISConfig.get_tag()],
|
||||
Annotated[ControlLoRADiffusersConfig, ControlLoRADiffusersConfig.get_tag()],
|
||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||
|
Reference in New Issue
Block a user