# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team """ Configuration definitions for image generation models. Typical usage: from invokeai.backend.model_manager import ModelConfigFactory raw = dict(path='models/sd-1/main/foo.ckpt', name='foo', base='sd-1', type='main', config='configs/stable-diffusion/v1-inference.yaml', variant='normal', format='checkpoint' ) config = ModelConfigFactory.make_config(raw) print(config.name) Validation errors will raise an InvalidModelConfigException error. """ from enum import Enum from typing import Literal, Optional, Type, Union from fastapi import Body from pydantic import BaseModel, ConfigDict, Field, TypeAdapter from typing_extensions import Annotated class InvalidModelConfigException(Exception): """Exception for when config parser doesn't recognized this combination of model type and format.""" class BaseModelType(str, Enum): """Base model type.""" Any = "any" StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" StableDiffusionXL = "sdxl" StableDiffusionXLRefiner = "sdxl-refiner" # Kandinsky2_1 = "kandinsky-2.1" class ModelType(str, Enum): """Model type.""" ONNX = "onnx" Main = "main" Vae = "vae" Lora = "lora" ControlNet = "controlnet" # used by model_probe TextualInversion = "embedding" IPAdapter = "ip_adapter" CLIPVision = "clip_vision" T2IAdapter = "t2i_adapter" class SubModelType(str, Enum): """Submodel type.""" UNet = "unet" TextEncoder = "text_encoder" TextEncoder2 = "text_encoder_2" Tokenizer = "tokenizer" Tokenizer2 = "tokenizer_2" Vae = "vae" VaeDecoder = "vae_decoder" VaeEncoder = "vae_encoder" Scheduler = "scheduler" SafetyChecker = "safety_checker" class ModelVariantType(str, Enum): """Variant type.""" Normal = "normal" Inpaint = "inpaint" Depth = "depth" class ModelFormat(str, Enum): """Storage format of model.""" Diffusers = "diffusers" Checkpoint = "checkpoint" Lycoris = "lycoris" Onnx = "onnx" Olive = "olive" EmbeddingFile = "embedding_file" EmbeddingFolder = "embedding_folder" InvokeAI = "invokeai" class SchedulerPredictionType(str, Enum): """Scheduler prediction type.""" Epsilon = "epsilon" VPrediction = "v_prediction" Sample = "sample" class ModelConfigBase(BaseModel): """Base class for model configuration information.""" path: str name: str base: BaseModelType type: ModelType format: ModelFormat key: str = Field(description="unique key for model", default="") original_hash: Optional[str] = Field( description="original fasthash of model contents", default=None ) # this is assigned at install time and will not change current_hash: Optional[str] = Field( description="current fasthash of model contents", default=None ) # if model is converted or otherwise modified, this will hold updated hash description: Optional[str] = Field(None) source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None) model_config = ConfigDict( use_enum_values=False, validate_assignment=True, ) def update(self, attributes: dict): """Update the object with fields in dict.""" for key, value in attributes.items(): setattr(self, key, value) # may raise a validation error class _CheckpointConfig(ModelConfigBase): """Model config for checkpoint-style models.""" format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint config: str = Field(description="path to the checkpoint model config file") class _DiffusersConfig(ModelConfigBase): """Model config for diffusers-style models.""" format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers class LoRAConfig(ModelConfigBase): """Model config for LoRA/Lycoris models.""" type: Literal[ModelType.Lora] = ModelType.Lora format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers] class VaeCheckpointConfig(ModelConfigBase): """Model config for standalone VAE models.""" type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint class VaeDiffusersConfig(ModelConfigBase): """Model config for standalone VAE models (diffusers version).""" type: Literal[ModelType.Vae] = ModelType.Vae format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers class ControlNetDiffusersConfig(_DiffusersConfig): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers class ControlNetCheckpointConfig(_CheckpointConfig): """Model config for ControlNet models (diffusers version).""" type: Literal[ModelType.ControlNet] = ModelType.ControlNet format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint class TextualInversionConfig(ModelConfigBase): """Model config for textual inversion embeddings.""" type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder] class _MainConfig(ModelConfigBase): """Model config for main models.""" vae: Optional[str] = Field(default=None) variant: ModelVariantType = ModelVariantType.Normal ztsnr_training: bool = False class MainCheckpointConfig(_CheckpointConfig, _MainConfig): """Model config for main checkpoint models.""" type: Literal[ModelType.Main] = ModelType.Main # Note that we do not need prediction_type or upcast_attention here # because they are provided in the checkpoint's own config file. class MainDiffusersConfig(_DiffusersConfig, _MainConfig): """Model config for main diffusers models.""" type: Literal[ModelType.Main] = ModelType.Main prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False class ONNXSD1Config(_MainConfig): """Model config for ONNX format models based on sd-1.""" type: Literal[ModelType.ONNX] = ModelType.ONNX format: Literal[ModelFormat.Onnx, ModelFormat.Olive] base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1 prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon upcast_attention: bool = False class ONNXSD2Config(_MainConfig): """Model config for ONNX format models based on sd-2.""" type: Literal[ModelType.ONNX] = ModelType.ONNX format: Literal[ModelFormat.Onnx, ModelFormat.Olive] # No yaml config file for ONNX, so these are part of config base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2 prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction upcast_attention: bool = True class IPAdapterConfig(ModelConfigBase): """Model config for IP Adaptor format models.""" type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter format: Literal[ModelFormat.InvokeAI] class CLIPVisionDiffusersConfig(ModelConfigBase): """Model config for ClipVision.""" type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision format: Literal[ModelFormat.Diffusers] class T2IConfig(ModelConfigBase): """Model config for T2I.""" type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter format: Literal[ModelFormat.Diffusers] _ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")] _ControlNetConfig = Annotated[ Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig], Field(discriminator="format") ] _VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")] _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")] AnyModelConfig = Annotated[ Union[ _MainModelConfig, _ONNXConfig, _VaeConfig, _ControlNetConfig, LoRAConfig, TextualInversionConfig, IPAdapterConfig, CLIPVisionDiffusersConfig, T2IConfig, ], Body(discriminator="type"), ] AnyModelConfigValidator = TypeAdapter(AnyModelConfig) class ModelConfigFactory(object): """Class for parsing config dicts into StableDiffusion Config obects.""" @classmethod def make_config( cls, model_data: Union[dict, ModelConfigBase], key: Optional[str] = None, dest_class: Optional[Type] = None, ) -> AnyModelConfig: """ Return the appropriate config object from raw dict values. :param model_data: A raw dict corresponding the obect fields to be parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase object, which will be passed through unchanged. :param dest_class: The config class to be returned. If not provided, will be selected automatically. """ if isinstance(model_data, ModelConfigBase): if key: model_data.key = key return model_data else: model = AnyModelConfigValidator.validate_python(model_data) if key: model.key = key return model