diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index e22f74c767..87c2feab5b 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union import torch from diffusers import ModelMixin -from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter from typing_extensions import Annotated, Any, Dict from ..raw_model import RawModel @@ -126,8 +126,6 @@ class ModelConfigBase(BaseModel): path: str = Field(description="filesystem path to the model file or directory") name: str = Field(description="model name") base: BaseModelType = Field(description="base model") - type: ModelType = Field(description="type of the model") - format: ModelFormat = Field(description="model format") key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None @@ -171,11 +169,26 @@ class _DiffusersConfig(ModelConfigBase): repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT -class LoRAConfig(ModelConfigBase): +class LoRALycorisConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" type: Literal[ModelType.Lora] = ModelType.Lora - format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] + format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}") + + +class LoRADiffusersConfig(ModelConfigBase): + """Model config for LoRA/Diffusers models.""" + + type: Literal[ModelType.Lora] = ModelType.Lora + format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}") class VaeCheckpointConfig(ModelConfigBase): @@ -184,6 +197,10 @@ class 
VaeCheckpointConfig(ModelConfigBase): type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}") + class VaeDiffusersConfig(ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" @@ -191,6 +208,10 @@ class VaeDiffusersConfig(ModelConfigBase): type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}") + class ControlNetDiffusersConfig(_DiffusersConfig): """Model config for ControlNet models (diffusers version).""" @@ -198,6 +219,10 @@ class ControlNetDiffusersConfig(_DiffusersConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}") + class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" @@ -205,12 +230,31 @@ class ControlNetCheckpointConfig(_CheckpointConfig): type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}") -class TextualInversionConfig(ModelConfigBase): + +class TextualInversionFileConfig(ModelConfigBase): """Model config for textual inversion embeddings.""" type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion - format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] + format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}") + + +class 
TextualInversionFolderConfig(ModelConfigBase): + """Model config for textual inversion embeddings.""" + + type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion + format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder + + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}") class _MainConfig(ModelConfigBase): @@ -228,12 +272,20 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig): type: Literal[ModelType.Main] = ModelType.Main + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}") + class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}") + class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" @@ -242,6 +294,10 @@ class IPAdapterConfig(ModelConfigBase): image_encoder_model_id: str format: Literal[ModelFormat.InvokeAI] + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}") + class CLIPVisionDiffusersConfig(ModelConfigBase): """Model config for ClipVision.""" @@ -249,36 +305,53 @@ class CLIPVisionDiffusersConfig(ModelConfigBase): type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision format: Literal[ModelFormat.Diffusers] + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}") -class T2IConfig(ModelConfigBase): + +class T2IAdapterConfig(ModelConfigBase): """Model config for T2I.""" type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter format: Literal[ModelFormat.Diffusers] + @staticmethod + def get_tag() -> Tag: + return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}") -_ControlNetConfig = Annotated[ - Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], - 
def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator

    The value has the form ``"<type>.<format>"``, the same shape as the tags
    built by each config class's ``get_tag()``.

    Args:
        v: Either a dict holding ``type`` and ``format`` keys, or any object
            exposing ``type`` and ``format`` attributes.

    Returns:
        The ``type`` and ``format`` values rendered with an f-string and
        joined by a dot. A missing key or attribute renders as ``None``.
    """
    if isinstance(v, dict):
        return f"{v.get('type')}.{v.get('format')}"  # pyright: ignore [reportUnknownMemberType]
    # `getattr` is a builtin, not a method: the previous `v.getattr('type')`
    # raised AttributeError for ordinary objects. The `None` default mirrors
    # the `.get()` behaviour of the dict branch above.
    # NOTE(review): the values are rendered with the same f-string formatting
    # that `get_tag()` uses, so both sides must keep stringifying the enums
    # identically - confirm if either side is ever changed.
    return f"{getattr(v, 'type', None)}.{getattr(v, 'format', None)}"


# Tagged union of every concrete model config; the callable discriminator
# above selects the member whose tag equals "<type>.<format>".
AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
        Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)

# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below.
However, it breaks FastAPI when used as the input Body parameter in a route.