diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 1040e598f0..00ee3471d1 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -18,8 +18,6 @@ from ..dependencies import ApiDependencies model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"]) -ModelConfigValidator = TypeAdapter(AnyModelConfig) - class ModelsList(BaseModel): """Return list of configs.""" diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 7a6a2589c6..187e28d27d 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -22,7 +22,8 @@ Validation errors will raise an InvalidModelConfigException error. from enum import Enum from typing import Literal, Optional, Type, Union -from pydantic import BaseModel, ConfigDict, Field, ValidationError +from pydantic import BaseModel, ConfigDict, Field, TypeAdapter +from typing_extensions import Annotated class InvalidModelConfigException(Exception): @@ -143,36 +144,42 @@ class _DiffusersConfig(ModelConfigBase): class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" + type: Literal[ModelType.Lora] = ModelType.Lora format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] class VaeCheckpointConfig(ModelConfigBase): """Model config for standalone VAE models.""" + type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint class VaeDiffusersConfig(ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" + type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers class ControlNetDiffusersConfig(_DiffusersConfig): """Model config for ControlNet models (diffusers version).""" + type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" + type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint class TextualInversionConfig(ModelConfigBase): """Model config for textual inversion embeddings.""" + type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] @@ -187,6 +194,7 @@ class _MainConfig(ModelConfigBase): class MainCheckpointConfig(_CheckpointConfig, _MainConfig): """Model config for main checkpoint models.""" + type: Literal[ModelType.Main] = ModelType.Main # Note that we do not need prediction_type or upcast_attention here # because they are provided in the checkpoint's own config file. @@ -194,6 +202,7 @@ class MainCheckpointConfig(_CheckpointConfig, _MainConfig): class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" + type: Literal[ModelType.Main] = ModelType.Main prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False @@ -201,7 +210,9 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig): class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" + type: Literal[ModelType.ONNX] = ModelType.ONNX format: Literal[ModelFormat.Onnx, ModelFormat.Olive] + base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1 prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False @@ -209,8 +220,10 @@ class ONNXSD1Config(_MainConfig): class ONNXSD2Config(_MainConfig): """Model config for ONNX format models based on sd-2.""" + type: Literal[ModelType.ONNX] = ModelType.ONNX format: Literal[ModelFormat.Onnx, ModelFormat.Olive] # No yaml config file for ONNX, so these are part of config + base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2 prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction upcast_attention: bool = True @@ -218,79 +231,49 @@ class ONNXSD2Config(_MainConfig): class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" + type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter format: Literal[ModelFormat.InvokeAI] class CLIPVisionDiffusersConfig(ModelConfigBase): """Model config for ClipVision.""" + type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision format: Literal[ModelFormat.Diffusers] class T2IConfig(ModelConfigBase): """Model config for T2I.""" + type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter format: Literal[ModelFormat.Diffusers] -AnyModelConfig = Union[ - MainCheckpointConfig, - MainDiffusersConfig, - LoRAConfig, - TextualInversionConfig, - ONNXSD1Config, - ONNXSD2Config, - VaeCheckpointConfig, - VaeDiffusersConfig, - ControlNetDiffusersConfig, - ControlNetCheckpointConfig, - IPAdapterConfig, - CLIPVisionDiffusersConfig, - T2IConfig, +_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator='base')] +_ControlNetConfig = Annotated[Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator='format')] +_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator='format')] +_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator='format')] + +AnyModelConfig = Annotated[ + Union[ + _MainModelConfig, + _ONNXConfig, + _VaeConfig, + _ControlNetConfig, + LoRAConfig, + TextualInversionConfig, + IPAdapterConfig, + CLIPVisionDiffusersConfig, + T2IConfig, + ], + Field(discriminator='type') ] +AnyModelConfigValidator = TypeAdapter(AnyModelConfig) class ModelConfigFactory(object): """Class for parsing config dicts into StableDiffusion Config obects.""" - _class_map: dict = { - ModelFormat.Checkpoint: { - ModelType.Main: MainCheckpointConfig, - ModelType.Vae: VaeCheckpointConfig, - }, - ModelFormat.Diffusers: { - ModelType.Main: MainDiffusersConfig, - ModelType.Lora: LoRAConfig, - ModelType.Vae: VaeDiffusersConfig, - ModelType.ControlNet: ControlNetDiffusersConfig, - ModelType.CLIPVision: CLIPVisionDiffusersConfig, - }, - ModelFormat.Lycoris: { - ModelType.Lora: LoRAConfig, - }, - ModelFormat.Onnx: { - ModelType.ONNX: { - BaseModelType.StableDiffusion1: ONNXSD1Config, - BaseModelType.StableDiffusion2: ONNXSD2Config, - }, - }, - ModelFormat.Olive: { - ModelType.ONNX: { - BaseModelType.StableDiffusion1: ONNXSD1Config, - BaseModelType.StableDiffusion2: ONNXSD2Config, - }, - }, - ModelFormat.EmbeddingFile: { - ModelType.TextualInversion: TextualInversionConfig, - }, - ModelFormat.EmbeddingFolder: { - ModelType.TextualInversion: TextualInversionConfig, - }, - ModelFormat.InvokeAI: { - ModelType.IPAdapter: IPAdapterConfig, - }, - } - @classmethod def make_config( cls, @@ -311,18 +294,8 @@ class ModelConfigFactory(object): if key: model_data.key = key return model_data - try: - format = model_data.get("format") - type = model_data.get("type") - model_base = model_data.get("base") - class_to_return = dest_class or cls._class_map[format][type] - if isinstance(class_to_return, dict): # additional level allowed - class_to_return = class_to_return[model_base] - model = class_to_return.model_validate(model_data) + else: + model = AnyModelConfigValidator.validate_python(model_data) if key: - model.key = key # ensure consistency + model.key = key return model - except KeyError as exc: - raise InvalidModelConfigException(f"Unknown combination of format '{format}' and type '{type}'") from exc - except ValidationError as exc: - raise InvalidModelConfigException(f"Invalid model configuration passed: {str(exc)}") from exc