replace _class_map in ModelConfigFactory with a nested discriminated union

This commit is contained in:
Lincoln Stein 2023-11-10 19:14:15 -05:00
parent bd56e9bc81
commit 3a6ba236f5
2 changed files with 39 additions and 68 deletions

View File

@ -18,8 +18,6 @@ from ..dependencies import ApiDependencies
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"]) model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
ModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelsList(BaseModel): class ModelsList(BaseModel):
"""Return list of configs.""" """Return list of configs."""

View File

@ -22,7 +22,8 @@ Validation errors will raise an InvalidModelConfigException error.
from enum import Enum from enum import Enum
from typing import Literal, Optional, Type, Union from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationError from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):
@ -143,36 +144,42 @@ class _DiffusersConfig(ModelConfigBase):
class LoRAConfig(ModelConfigBase): class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models.""" """Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase): class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models.""" """Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase): class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version).""" """Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig): class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig): class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version).""" """Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase): class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings.""" """Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
@ -187,6 +194,7 @@ class _MainConfig(ModelConfigBase):
class MainCheckpointConfig(_CheckpointConfig, _MainConfig): class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models.""" """Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
# Note that we do not need prediction_type or upcast_attention here # Note that we do not need prediction_type or upcast_attention here
# because they are provided in the checkpoint's own config file. # because they are provided in the checkpoint's own config file.
@ -194,6 +202,7 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
class MainDiffusersConfig(_DiffusersConfig, _MainConfig): class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models.""" """Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False
@ -201,7 +210,9 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
class ONNXSD1Config(_MainConfig): class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1.""" """Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False upcast_attention: bool = False
@ -209,8 +220,10 @@ class ONNXSD1Config(_MainConfig):
class ONNXSD2Config(_MainConfig): class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2.""" """Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive] format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config # No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True upcast_attention: bool = True
@ -218,79 +231,49 @@ class ONNXSD2Config(_MainConfig):
class IPAdapterConfig(ModelConfigBase): class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models.""" """Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI] format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase): class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision.""" """Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase): class T2IConfig(ModelConfigBase):
"""Model config for T2I.""" """Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers] format: Literal[ModelFormat.Diffusers]
AnyModelConfig = Union[ _ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')]
MainCheckpointConfig, _ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')]
MainDiffusersConfig, _VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')]
LoRAConfig, _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')]
TextualInversionConfig,
ONNXSD1Config, AnyModelConfig = Annotated[
ONNXSD2Config, Union[
VaeCheckpointConfig, _MainModelConfig,
VaeDiffusersConfig, _ONNXConfig,
ControlNetDiffusersConfig, _VaeConfig,
ControlNetCheckpointConfig, _ControlNetConfig,
IPAdapterConfig, LoRAConfig,
CLIPVisionDiffusersConfig, TextualInversionConfig,
T2IConfig, IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
],
Field(discriminator='type')
] ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
class ModelConfigFactory(object): class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects.""" """Class for parsing config dicts into StableDiffusion Config obects."""
_class_map: dict = {
ModelFormat.Checkpoint: {
ModelType.Main: MainCheckpointConfig,
ModelType.Vae: VaeCheckpointConfig,
},
ModelFormat.Diffusers: {
ModelType.Main: MainDiffusersConfig,
ModelType.Lora: LoRAConfig,
ModelType.Vae: VaeDiffusersConfig,
ModelType.ControlNet: ControlNetDiffusersConfig,
ModelType.CLIPVision: CLIPVisionDiffusersConfig,
},
ModelFormat.Lycoris: {
ModelType.Lora: LoRAConfig,
},
ModelFormat.Onnx: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.Olive: {
ModelType.ONNX: {
BaseModelType.StableDiffusion1: ONNXSD1Config,
BaseModelType.StableDiffusion2: ONNXSD2Config,
},
},
ModelFormat.EmbeddingFile: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.EmbeddingFolder: {
ModelType.TextualInversion: TextualInversionConfig,
},
ModelFormat.InvokeAI: {
ModelType.IPAdapter: IPAdapterConfig,
},
}
@classmethod @classmethod
def make_config( def make_config(
cls, cls,
@ -311,18 +294,8 @@ class ModelConfigFactory(object):
if key: if key:
model_data.key = key model_data.key = key
return model_data return model_data
try: else:
format = model_data.get("format") model = AnyModelConfigValidator.validate_python(model_data)
type = model_data.get("type")
model_base = model_data.get("base")
class_to_return = dest_class or cls._class_map[format][type]
if isinstance(class_to_return, dict): # additional level allowed
class_to_return = class_to_return[model_base]
model = class_to_return.model_validate(model_data)
if key: if key:
model.key = key # ensure consistency model.key = key
return model return model
except KeyError as exc:
raise InvalidModelConfigException(f"Unknown combination of format '{format}' and type '{type}'") from exc
except ValidationError as exc:
raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc