feat(mm): use callable discriminator for AnyModelConfig union

This commit is contained in:
psychedelicious 2024-03-01 12:57:46 +11:00
parent 8b34f5298c
commit 316573df2d

View File

@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel
@ -126,8 +126,6 @@ class ModelConfigBase(BaseModel):
path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
@ -171,11 +169,26 @@ class _DiffusersConfig(ModelConfigBase):
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRAConfig(ModelConfigBase):
class LoRALycorisConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
class LoRADiffusersConfig(ModelConfigBase):
"""Model config for LoRA/Diffusers models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
class VaeCheckpointConfig(ModelConfigBase):
@ -184,6 +197,10 @@ class VaeCheckpointConfig(ModelConfigBase):
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
@ -191,6 +208,10 @@ class VaeDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
@ -198,6 +219,10 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
@ -205,12 +230,31 @@ class ControlNetCheckpointConfig(_CheckpointConfig):
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
class TextualInversionConfig(ModelConfigBase):
class TextualInversionFileConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
class TextualInversionFolderConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
class _MainConfig(ModelConfigBase):
@ -228,12 +272,20 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
@ -242,6 +294,10 @@ class IPAdapterConfig(ModelConfigBase):
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
@ -249,36 +305,53 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
class T2IConfig(ModelConfigBase):
class T2IAdapterConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
def get_model_discriminator_value(v: Any) -> str:
"""
Computes the discriminator value for a model config.
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
"""
if isinstance(v, dict):
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
return f"{v.getattr('type')}.{v.getattr('format')}"
AnyModelConfig = Annotated[
Union[
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
],
Discriminator(get_model_discriminator_value),
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.