# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.

Typical usage:

  from invokeai.backend.model_manager import ModelConfigFactory
  raw = dict(path='models/sd-1/main/foo.ckpt',
             name='foo',
             base='sd-1',
             type='main',
             config='configs/stable-diffusion/v1-inference.yaml',
             variant='normal',
             format='checkpoint'
            )
  config = ModelConfigFactory.make_config(raw)
  print(config.name)

NOTE(review): this module defines InvalidModelConfigException, but nothing
in this file raises it; a dict that fails validation surfaces whatever error
pydantic raises. Confirm against callers before relying on either behavior.
"""
from enum import Enum
from typing import Literal, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict


class InvalidModelConfigException(Exception):
    """Exception for when the config parser doesn't recognize a combination of model type and format."""


class BaseModelType(str, Enum):
    """Base model type (string-valued enum)."""

    Any = "any"
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    # Kandinsky2_1 = "kandinsky-2.1"


class ModelType(str, Enum):
    """Model type (string-valued enum)."""

    ONNX = "onnx"
    Main = "main"
    Vae = "vae"
    Lora = "lora"
    # ControlNet is used by model_probe
    ControlNet = "controlnet"
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    T2IAdapter = "t2i_adapter"


class SubModelType(str, Enum):
    """Submodel type (string-valued enum)."""

    UNet = "unet"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Vae = "vae"
    VaeDecoder = "vae_decoder"
    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"


class ModelVariantType(str, Enum):
    """Variant type (string-valued enum)."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


class ModelFormat(str, Enum):
    """Storage format of a model (string-valued enum)."""

    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    Lycoris = "lycoris"
    Onnx = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"


class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type (string-valued enum)."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"


class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    # DEFAULT: model files without "fp16" or other qualifier
    DEFAULT = "default"
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OPENVINO = "openvino"
    FLAX = "flax"


class ModelConfigBase(BaseModel):
    """
    Base class for model configuration information.

    Declares the fields shared by every model config. Assignment to a field
    is validated (``validate_assignment=True``) and enum-typed fields keep
    their enum members rather than raw values (``use_enum_values=False``).
    """

    path: str
    name: str
    base: BaseModelType
    type: ModelType
    format: ModelFormat
    key: str = Field(default="", description="unique key for model")
    # original_hash is assigned at install time and will not change
    original_hash: Optional[str] = Field(
        default=None,
        description="original fasthash of model contents",
    )
    # if model is converted or otherwise modified, current_hash will hold the updated hash
    current_hash: Optional[str] = Field(
        default=None,
        description="current fasthash of model contents",
    )
    description: Optional[str] = Field(default=None)
    source: Optional[str] = Field(default=None, description="Model download source (URL or repo_id)")

    model_config = ConfigDict(
        use_enum_values=False,
        validate_assignment=True,
    )

    def update(self, attributes: Dict[str, Any]) -> None:
        """
        Assign each entry of ``attributes`` onto this object.

        :param attributes: Mapping of attribute name to new value. Entries are
          applied one at a time in iteration order, so a validation error on a
          later entry leaves the earlier assignments in place.
        """
        for field_name, new_value in attributes.items():
            # assignment goes through pydantic validation and may raise
            setattr(self, field_name, new_value)


class _CheckpointConfig(ModelConfigBase):
    """Shared fields for checkpoint-format model configs."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    config: str = Field(description="path to the checkpoint model config file")


class _DiffusersConfig(ModelConfigBase):
    """Shared fields for diffusers-format model configs."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class LoRAConfig(ModelConfigBase):
    """Model config for LoRA/Lycoris models."""

    type: Literal[ModelType.Lora] = ModelType.Lora
    format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]


class VaeCheckpointConfig(ModelConfigBase):
    """Model config for standalone VAE models (checkpoint version)."""

    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class VaeDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class ControlNetDiffusersConfig(_DiffusersConfig):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class ControlNetCheckpointConfig(_CheckpointConfig):
    """Model config for ControlNet models (checkpoint version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class TextualInversionConfig(ModelConfigBase):
    """Model config for textual inversion embeddings (single file or folder)."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]


class _MainConfig(ModelConfigBase):
    """Shared fields for main model configs."""

    vae: Optional[str] = Field(default=None)
    variant: ModelVariantType = ModelVariantType.Normal
    ztsnr_training: bool = False


class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main


class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
    """Model config for main diffusers models."""

    type: Literal[ModelType.Main] = ModelType.Main
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD1Config(_MainConfig):
    """Model config for ONNX format models based on sd-1."""

    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD2Config(_MainConfig):
    """Model config for ONNX format models based on sd-2."""

    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # No yaml config file for ONNX, so these are part of config
    base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
    upcast_attention: bool = True


class IPAdapterConfig(ModelConfigBase):
    """Model config for IP-Adapter format models."""

    type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
    format: Literal[ModelFormat.InvokeAI]


class CLIPVisionDiffusersConfig(ModelConfigBase):
    """Model config for CLIP Vision models (diffusers format)."""

    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers]


class T2IConfig(ModelConfigBase):
    """Model config for T2I adapter models (diffusers format)."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers]


# Per-type unions. Where one model type has several config classes, the
# named discriminator field tells pydantic which class to validate against.
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
    Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]

AnyModelConfig = Union[
    _MainModelConfig,
    _ONNXConfig,
    _VaeConfig,
    _ControlNetConfig,
    LoRAConfig,
    TextualInversionConfig,
    IPAdapterConfig,
    CLIPVisionDiffusersConfig,
    T2IConfig,
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)

# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue.
# Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
#     Union[
#         _MainModelConfig,
#         _ONNXConfig,
#         _VaeConfig,
#         _ControlNetConfig,
#         LoRAConfig,
#         TextualInversionConfig,
#         IPAdapterConfig,
#         CLIPVisionDiffusersConfig,
#         T2IConfig,
#     ],
#     Field(discriminator="type"),
# ]


class ModelConfigFactory(object):
    """Class for parsing config dicts into StableDiffusion Config objects."""

    @classmethod
    def make_config(
        cls,
        model_data: Union[dict, AnyModelConfig],
        key: Optional[str] = None,
        dest_class: Optional[Type] = None,
    ) -> AnyModelConfig:
        """
        Return the appropriate config object from raw dict values.

        :param model_data: A raw dict corresponding to the object fields to be
          parsed into a ModelConfigBase object (or descendant), or a
          ModelConfigBase object, which is passed through without re-validation.
        :param key: If non-empty, assigned to the ``key`` attribute of the
          returned config. When ``model_data`` is already a config object this
          modifies that same object in place.
        :param dest_class: The config class to be returned. If not provided,
          will be selected automatically.
        :return: The validated (or passed-through) config object.
        """
        if isinstance(model_data, ModelConfigBase):
            model = model_data
        elif dest_class is not None:
            # A pydantic model class validates through ``model_validate``;
            # ``validate_python`` exists only on TypeAdapter, so calling it on
            # a config class raised AttributeError. Keep the old call as a
            # fallback for adapter-style objects that do provide it.
            if hasattr(dest_class, "model_validate"):
                model = dest_class.model_validate(model_data)
            else:
                model = dest_class.validate_python(model_data)
        else:
            model = AnyModelConfigValidator.validate_python(model_data)

        # An empty or missing key leaves the model's existing key untouched.
        if key:
            model.key = key
        return model