mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): use callable discriminator for AnyModelConfig
union
This commit is contained in:
parent
8b34f5298c
commit
316573df2d
@ -26,7 +26,7 @@ from typing import Literal, Optional, Type, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from diffusers import ModelMixin
|
from diffusers import ModelMixin
|
||||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from ..raw_model import RawModel
|
from ..raw_model import RawModel
|
||||||
@ -126,8 +126,6 @@ class ModelConfigBase(BaseModel):
|
|||||||
path: str = Field(description="filesystem path to the model file or directory")
|
path: str = Field(description="filesystem path to the model file or directory")
|
||||||
name: str = Field(description="model name")
|
name: str = Field(description="model name")
|
||||||
base: BaseModelType = Field(description="base model")
|
base: BaseModelType = Field(description="base model")
|
||||||
type: ModelType = Field(description="type of the model")
|
|
||||||
format: ModelFormat = Field(description="model format")
|
|
||||||
key: str = Field(description="unique key for model", default="<NOKEY>")
|
key: str = Field(description="unique key for model", default="<NOKEY>")
|
||||||
original_hash: Optional[str] = Field(
|
original_hash: Optional[str] = Field(
|
||||||
description="original fasthash of model contents", default=None
|
description="original fasthash of model contents", default=None
|
||||||
@ -171,11 +169,26 @@ class _DiffusersConfig(ModelConfigBase):
|
|||||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
||||||
|
|
||||||
|
|
||||||
class LoRAConfig(ModelConfigBase):
|
class LoRALycorisConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Lycoris models."""
|
"""Model config for LoRA/Lycoris models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||||
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Lora}.{ModelFormat.Lycoris}")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRADiffusersConfig(ModelConfigBase):
|
||||||
|
"""Model config for LoRA/Diffusers models."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.Lora] = ModelType.Lora
|
||||||
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Lora}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
class VaeCheckpointConfig(ModelConfigBase):
|
class VaeCheckpointConfig(ModelConfigBase):
|
||||||
@ -184,6 +197,10 @@ class VaeCheckpointConfig(ModelConfigBase):
|
|||||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Vae}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
class VaeDiffusersConfig(ModelConfigBase):
|
class VaeDiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for standalone VAE models (diffusers version)."""
|
"""Model config for standalone VAE models (diffusers version)."""
|
||||||
@ -191,6 +208,10 @@ class VaeDiffusersConfig(ModelConfigBase):
|
|||||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
type: Literal[ModelType.Vae] = ModelType.Vae
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Vae}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetDiffusersConfig(_DiffusersConfig):
|
class ControlNetDiffusersConfig(_DiffusersConfig):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
@ -198,6 +219,10 @@ class ControlNetDiffusersConfig(_DiffusersConfig):
|
|||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointConfig(_CheckpointConfig):
|
class ControlNetCheckpointConfig(_CheckpointConfig):
|
||||||
"""Model config for ControlNet models (diffusers version)."""
|
"""Model config for ControlNet models (diffusers version)."""
|
||||||
@ -205,12 +230,31 @@ class ControlNetCheckpointConfig(_CheckpointConfig):
|
|||||||
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.ControlNet}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
class TextualInversionConfig(ModelConfigBase):
|
|
||||||
|
class TextualInversionFileConfig(ModelConfigBase):
|
||||||
"""Model config for textual inversion embeddings."""
|
"""Model config for textual inversion embeddings."""
|
||||||
|
|
||||||
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||||
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
|
format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFile}")
|
||||||
|
|
||||||
|
|
||||||
|
class TextualInversionFolderConfig(ModelConfigBase):
|
||||||
|
"""Model config for textual inversion embeddings."""
|
||||||
|
|
||||||
|
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
|
||||||
|
format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.TextualInversion}.{ModelFormat.EmbeddingFolder}")
|
||||||
|
|
||||||
|
|
||||||
class _MainConfig(ModelConfigBase):
|
class _MainConfig(ModelConfigBase):
|
||||||
@ -228,12 +272,20 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
|
|||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Main}.{ModelFormat.Checkpoint}")
|
||||||
|
|
||||||
|
|
||||||
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||||
"""Model config for main diffusers models."""
|
"""Model config for main diffusers models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Main] = ModelType.Main
|
type: Literal[ModelType.Main] = ModelType.Main
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.Main}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterConfig(ModelConfigBase):
|
class IPAdapterConfig(ModelConfigBase):
|
||||||
"""Model config for IP Adaptor format models."""
|
"""Model config for IP Adaptor format models."""
|
||||||
@ -242,6 +294,10 @@ class IPAdapterConfig(ModelConfigBase):
|
|||||||
image_encoder_model_id: str
|
image_encoder_model_id: str
|
||||||
format: Literal[ModelFormat.InvokeAI]
|
format: Literal[ModelFormat.InvokeAI]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.IPAdapter}.{ModelFormat.InvokeAI}")
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
class CLIPVisionDiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for ClipVision."""
|
"""Model config for ClipVision."""
|
||||||
@ -249,36 +305,53 @@ class CLIPVisionDiffusersConfig(ModelConfigBase):
|
|||||||
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
|
||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.CLIPVision}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
class T2IConfig(ModelConfigBase):
|
|
||||||
|
class T2IAdapterConfig(ModelConfigBase):
|
||||||
"""Model config for T2I."""
|
"""Model config for T2I."""
|
||||||
|
|
||||||
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
|
||||||
format: Literal[ModelFormat.Diffusers]
|
format: Literal[ModelFormat.Diffusers]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tag() -> Tag:
|
||||||
|
return Tag(f"{ModelType.T2IAdapter}.{ModelFormat.Diffusers}")
|
||||||
|
|
||||||
_ControlNetConfig = Annotated[
|
|
||||||
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
|
|
||||||
Field(discriminator="format"),
|
|
||||||
]
|
|
||||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
|
||||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
|
||||||
|
|
||||||
AnyModelConfig = Union[
|
def get_model_discriminator_value(v: Any) -> str:
|
||||||
_MainModelConfig,
|
"""
|
||||||
_VaeConfig,
|
Computes the discriminator value for a model config.
|
||||||
_ControlNetConfig,
|
https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
|
||||||
# ModelConfigBase,
|
"""
|
||||||
LoRAConfig,
|
if isinstance(v, dict):
|
||||||
TextualInversionConfig,
|
return f"{v.get('type')}.{v.get('format')}" # pyright: ignore [reportUnknownMemberType]
|
||||||
IPAdapterConfig,
|
return f"{v.getattr('type')}.{v.getattr('format')}"
|
||||||
CLIPVisionDiffusersConfig,
|
|
||||||
T2IConfig,
|
|
||||||
|
AnyModelConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||||
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||||
|
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
|
||||||
|
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
|
||||||
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||||
|
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||||
|
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
|
||||||
|
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||||
|
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||||
|
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||||
|
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
|
||||||
|
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
|
||||||
|
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
|
||||||
|
],
|
||||||
|
Discriminator(get_model_discriminator_value),
|
||||||
]
|
]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
# IMPLEMENTATION NOTE:
|
# IMPLEMENTATION NOTE:
|
||||||
# The preferred alternative to the above is a discriminated Union as shown
|
# The preferred alternative to the above is a discriminated Union as shown
|
||||||
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user