replace _class_map in ModelConfigFactory with a nested discriminated union

This commit is contained in:
Lincoln Stein 2023-11-10 19:14:15 -05:00
parent bd56e9bc81
commit 3a6ba236f5
2 changed files with 39 additions and 68 deletions

View File

@ -18,8 +18,6 @@ from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
ModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelsList(BaseModel):
"""Return list of configs."""

View File

@ -22,7 +22,8 @@ Validation errors will raise an InvalidModelConfigException error.
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
class InvalidModelConfigException(Exception):
@ -143,36 +144,42 @@ class _DiffusersConfig(ModelConfigBase):
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
@ -187,6 +194,7 @@ class _MainConfig(ModelConfigBase):
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file.
@ -194,6 +202,7 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@ -201,7 +210,9 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
@ -209,8 +220,10 @@ class ONNXSD1Config(_MainConfig):
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
@ -218,79 +231,49 @@ class ONNXSD2Config(_MainConfig):
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
AnyModelConfig = Union[
MainCheckpointConfig,
MainDiffusersConfig,
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')]
_ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')]
AnyModelConfig = Annotated[
Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
LoRAConfig,
TextualInversionConfig,
ONNXSD1Config,
ONNXSD2Config,
VaeCheckpointConfig,
VaeDiffusersConfig,
ControlNetDiffusersConfig,
ControlNetCheckpointConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
],
Field(discriminator='type')
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
_class_map: dict = {
ModelFormat.Checkpoint: {
ModelType.Main: MainCheckpointConfig,
ModelType.Vae: VaeCheckpointConfig,
},
ModelFormat.Diffusers: {
ModelType.Main: MainDiffusersConfig,
ModelType.Lora: LoRAConfig,
ModelType.Vae: VaeDiffusersConfig,
ModelType.ControlNet: ControlNetDiffusersConfig,
ModelType.CLIPVision: CLIPVisionDiffusersConfig,
},
ModelFormat.Lycoris: {
ModelType.Lora: LoRAConfig,
},
ModelFormat.Onnx: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.Olive: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.EmbeddingFile: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.EmbeddingFolder: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.InvokeAI: {
ModelType.IPAdapter: IPAdapterConfig,
},
}
@classmethod
def make_config(
cls,
@ -311,18 +294,8 @@ class ModelConfigFactory(object):
if key:
model_data.key = key
return model_data
try:
format = model_data.get("format")
type = model_data.get("type")
model_base = model_data.get("base")
class_to_return = dest_class or cls._class_map[format][type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
model = class_to_return.model_validate(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key # ensure consistency
model.key = key
return model
except KeyError as exc:
raise InvalidModelConfigException(f"Unknown combination of format '{format}' and type '{type}'") from exc
except ValidationError as exc:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc