feat(mm): use callable discriminator for AnyModelConfig union

This commit is contained in:
psychedelicious 2024-03-01 12:57:46 +11:00
parent 8b34f5298c
commit 316573df2d

View File

@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
import torch import torch
from diffusers import ModelMixin from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel from ..raw_model import RawModel
@ -126,8 +126,6 @@ class ModelConfigBase(BaseModel):
path: str = Field(description="filesystem path to the model file or directory") path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name") name: str = Field(description="model name")
base: BaseModelType = Field(description="base model") base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>") key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field( original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None description="original fasthash of model contents", default=None
@ -171,11 +169,26 @@ class _DiffusersConfig(ModelConfigBase):
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRALycorisConfig(ModelConfigBase):
    """Config for a LoRA model stored in the Lycoris format."""

    # Both fields are pinned to a single enum member each, so together they
    # identify this config class.
    type: Literal[ModelType.Lora] = ModelType.Lora
    format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.Lora}.{ModelFormat.Lycoris}"
        return Tag(tag_name)
class LoRADiffusersConfig(ModelConfigBase):
    """Config for a LoRA model stored in the Diffusers format."""

    # Both fields are pinned to a single enum member each, so together they
    # identify this config class.
    type: Literal[ModelType.Lora] = ModelType.Lora
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.Lora}.{ModelFormat.Diffusers}"
        return Tag(tag_name)
class VaeCheckpointConfig(ModelConfigBase): class VaeCheckpointConfig(ModelConfigBase):
@ -184,6 +197,10 @@ class VaeCheckpointConfig(ModelConfigBase):
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
class VaeDiffusersConfig(ModelConfigBase): class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version).""" """Model config for standalone VAE models (diffusers version)."""
@ -191,6 +208,10 @@ class VaeDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.Vae] = ModelType.Vae type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
class ControlNetDiffusersConfig(_DiffusersConfig): class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
@ -198,6 +219,10 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
class ControlNetCheckpointConfig(_CheckpointConfig): class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
@ -205,12 +230,31 @@ class ControlNetCheckpointConfig(_CheckpointConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
class TextualInversionFileConfig(ModelConfigBase):
    """Config for a textual inversion embedding in the embedding-file format."""

    # Both fields are pinned to a single enum member each, so together they
    # identify this config class.
    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}"
        return Tag(tag_name)
class TextualInversionFolderConfig(ModelConfigBase):
    """Config for a textual inversion embedding in the embedding-folder format."""

    # Both fields are pinned to a single enum member each, so together they
    # identify this config class.
    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}"
        return Tag(tag_name)
class _MainConfig(ModelConfigBase): class _MainConfig(ModelConfigBase):
@ -228,12 +272,20 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
type: Literal[ModelType.Main] = ModelType.Main type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
    """Config for a main model in the Diffusers format."""

    # Only `type` is declared here; `format` is presumably inherited from
    # _DiffusersConfig -- verify against that base class.
    type: Literal[ModelType.Main] = ModelType.Main

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.Main}.{ModelFormat.Diffusers}"
        return Tag(tag_name)
class IPAdapterConfig(ModelConfigBase): class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models.""" """Model config for IP Adaptor format models."""
@ -242,6 +294,10 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
class CLIPVisionDiffusersConfig(ModelConfigBase): class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision.""" """Model config for ClipVision."""
@ -249,36 +305,53 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
class T2IAdapterConfig(ModelConfigBase):
    """Config for a T2I adapter model."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    # No default is declared for `format`, so it has to be supplied explicitly.
    format: Literal[ModelFormat.Diffusers]

    @staticmethod
    def get_tag() -> Tag:
        """Return the union tag for this config, formatted as "<type>.<format>"."""
        tag_name = f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}"
        return Tag(tag_name)
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.
    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator

    The value is the string "<type>.<format>", the same shape as the tags the
    config classes build in their `get_tag()` methods.

    Args:
        v: The value being validated: either a dict of raw fields or an
            object exposing `type` and `format` attributes.

    Returns:
        The "<type>.<format>" string. A missing key or attribute is rendered
        as "None", which is not expected to match any tag.
    """
    if isinstance(v, dict):
        # Raw input: read the keys without raising when one is absent.
        return f"{v.get('type')}.{v.get('format')}"  # pyright: ignore [reportUnknownMemberType]
    # Objects have no `.getattr()` method; use the builtin. The None default
    # mirrors the dict branch, so an unexpected input produces a non-matching
    # value instead of raising AttributeError here.
    return f"{getattr(v, 'type', None)}.{getattr(v, 'format', None)}"
# Tagged union of the concrete model configs. Each member is annotated with
# the tag returned by its own get_tag(); the callable discriminator computes
# a "<type>.<format>" string from the input to choose the member.
AnyModelConfig = Annotated[
    Union[
        # Main models
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        # VAEs
        Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
        Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
        # ControlNets
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        # LoRAs
        Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        # Textual inversion embeddings
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        # Adapters and image encoders
        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE: # IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown # The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route. # below. However, it breaks FastAPI when used as the input Body parameter in a route.