# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.

Typical usage:

  from invokeai.backend.model_manager import ModelConfigFactory
  raw = dict(path='models/sd-1/main/foo.ckpt',
             name='foo',
             base='sd-1',
             type='main',
             config='configs/stable-diffusion/v1-inference.yaml',
             variant='normal',
             format='checkpoint'
            )
  config = ModelConfigFactory.make_config(raw)
  print(config.name)

Validation is performed by pydantic; validation errors raised by pydantic
are propagated to the caller.

"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union

import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict

from ..raw_model import RawModel

# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]


class InvalidModelConfigException(Exception):
    """Exception for when config parser doesn't recognize this combination of model type and format."""


class BaseModelType(str, Enum):
    """Base model type."""

    Any = "any"
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    # Kandinsky2_1 = "kandinsky-2.1"


class ModelType(str, Enum):
    """Model type."""

    ONNX = "onnx"
    Main = "main"
    Vae = "vae"
    Lora = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    T2IAdapter = "t2i_adapter"


class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Vae = "vae"
    VaeDecoder = "vae_decoder"
    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"


class ModelVariantType(str, Enum):
    """Variant type."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


class ModelFormat(str, Enum):
    """Storage format of model."""

    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    Lycoris = "lycoris"
    Onnx = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"


class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"


class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    DEFAULT = ""  # model files without "fp16" or other qualifier - empty str
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OPENVINO = "openvino"
    FLAX = "flax"


class ModelConfigBase(BaseModel):
    """Base class for model configuration information.

    Holds the fields shared by every model config. Subclasses narrow `type`,
    `format` (and sometimes `base`) to `Literal` values so that the unions
    defined at the bottom of this module can select the right class.
    """

    path: str = Field(description="filesystem path to the model file or directory")
    name: str = Field(description="model name")
    base: BaseModelType = Field(description="base model")
    type: ModelType = Field(description="type of the model")
    format: ModelFormat = Field(description="model format")
    key: str = Field(description="unique key for model", default="")
    original_hash: Optional[str] = Field(
        description="original fasthash of model contents", default=None
    )  # this is assigned at install time and will not change
    current_hash: Optional[str] = Field(
        description="current fasthash of model contents", default=None
    )  # if model is converted or otherwise modified, this will hold updated hash
    description: Optional[str] = Field(description="human readable description of the model", default=None)
    source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
    last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
        """Mark fields that have defaults as required in the generated JSON schema.

        :param schema: The JSON schema dict being built; modified in place.
        :param model_class: The model class the schema is generated for (unused).
        """
        schema["required"].extend(
            ["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
        )

    model_config = ConfigDict(
        use_enum_values=False,
        validate_assignment=True,
        json_schema_extra=json_schema_extra,
    )

    def update(self, attributes: Dict[str, Any]) -> None:
        """Update the object with fields in dict.

        :param attributes: Mapping of field name to new value. Each entry is
          applied with `setattr`, so `validate_assignment` checks every value.
        """
        for key, value in attributes.items():
            setattr(self, key, value)  # may raise a validation error


class _CheckpointConfig(ModelConfigBase):
    """Model config for checkpoint-style models."""

    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
    config: str = Field(description="path to the checkpoint model config file")


class _DiffusersConfig(ModelConfigBase):
    """Model config for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
    repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT


class LoRAConfig(ModelConfigBase):
    """Model config for LoRA/Lycoris models."""

    type: Literal[ModelType.Lora] = ModelType.Lora
    format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]


class VaeCheckpointConfig(ModelConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class VaeDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    type: Literal[ModelType.Vae] = ModelType.Vae
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class ControlNetDiffusersConfig(_DiffusersConfig):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers


class ControlNetCheckpointConfig(_CheckpointConfig):
    """Model config for ControlNet models (checkpoint version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint


class TextualInversionConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]


class _MainConfig(ModelConfigBase):
    """Model config for main models."""

    vae: Optional[str] = Field(default=None)
    variant: ModelVariantType = ModelVariantType.Normal
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False
    ztsnr_training: bool = False


class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
    """Model config for main checkpoint models."""

    type: Literal[ModelType.Main] = ModelType.Main


class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
    """Model config for main diffusers models."""

    type: Literal[ModelType.Main] = ModelType.Main


class ONNXSD1Config(_MainConfig):
    """Model config for ONNX format models based on sd-1."""

    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False


class ONNXSD2Config(_MainConfig):
    """Model config for ONNX format models based on sd-2."""

    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # No yaml config file for ONNX, so these are part of config
    base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
    upcast_attention: bool = True


class ONNXSDXLConfig(_MainConfig):
    """Model config for ONNX format models based on sdxl."""

    type: Literal[ModelType.ONNX] = ModelType.ONNX
    format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
    # No yaml config file for ONNX, so these are part of config
    base: Literal[BaseModelType.StableDiffusionXL] = BaseModelType.StableDiffusionXL
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction


class IPAdapterConfig(ModelConfigBase):
    """Model config for IP Adapter format models."""

    type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
    image_encoder_model_id: str
    format: Literal[ModelFormat.InvokeAI]


class CLIPVisionDiffusersConfig(ModelConfigBase):
    """Model config for ClipVision."""

    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers]


class T2IConfig(ModelConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers]


# The ONNX configs share the same `type` and `format` choices, so they are
# told apart by their `base` field; the other families differ by `format`.
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config, ONNXSDXLConfig], Field(discriminator="base")]
_ControlNetConfig = Annotated[
    Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
    Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]

# Union of every concrete config family that the validator can produce.
AnyModelConfig = Union[
    _MainModelConfig,
    _ONNXConfig,
    _VaeConfig,
    _ControlNetConfig,
    # ModelConfigBase,
    LoRAConfig,
    TextualInversionConfig,
    IPAdapterConfig,
    CLIPVisionDiffusersConfig,
    T2IConfig,
]

AnyModelConfigValidator = TypeAdapter(AnyModelConfig)

# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue.
# Please see:
#   https://github.com/tiangolo/fastapi/discussions/9761 and
#   https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
#     Union[
#         _MainModelConfig,
#         _ONNXConfig,
#         _VaeConfig,
#         _ControlNetConfig,
#         LoRAConfig,
#         TextualInversionConfig,
#         IPAdapterConfig,
#         CLIPVisionDiffusersConfig,
#         T2IConfig,
#     ],
#     Field(discriminator="type"),
# ]


class ModelConfigFactory(object):
    """Class for parsing config dicts into StableDiffusion Config objects."""

    @classmethod
    def make_config(
        cls,
        model_data: Union[Dict[str, Any], AnyModelConfig],
        key: Optional[str] = None,
        dest_class: Optional[Type[ModelConfigBase]] = None,
        timestamp: Optional[float] = None,
    ) -> AnyModelConfig:
        """
        Return the appropriate config object from raw dict values.

        :param model_data: A raw dict corresponding the object fields to be
          parsed into a ModelConfigBase object (or descendent), or a
          ModelConfigBase object, which is used as-is without re-validation.
        :param key: If non-empty, assigned to the returned object's `key`
          field. This also applies to an object passed in as `model_data`,
          which is modified in place.
        :param dest_class: The config class to be returned. If not provided,
          will be selected automatically.
        :param timestamp: If not None, assigned to the returned object's
          `last_modified` field (same in-place caveat as `key`).
        :return: The validated config object.

        Errors raised by pydantic while validating `model_data`, or while
        assigning `key` / `timestamp`, are not caught here.
        """
        model: ModelConfigBase
        if isinstance(model_data, ModelConfigBase):
            model = model_data
        elif dest_class:
            model = dest_class.model_validate(model_data)
        else:
            # mypy doesn't typecheck TypeAdapters well?
            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore

        # An empty key means "not set" (it is the field's default), so a
        # truthiness test is intentional here.
        if key:
            model.key = key
        # Compare against None explicitly: 0.0 is a legitimate timestamp and
        # must not be silently discarded.
        if timestamp is not None:
            model.last_modified = timestamp
        return model  # type: ignore