InvokeAI/invokeai/backend/model_manager/config.py

322 lines
9.7 KiB
Python

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str
name: str
base: BaseModelType
type: ModelType
format: ModelFormat
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
"""Model config for ONNX format models based on sd-1."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
base: Literal[BaseModelType.StableDiffusion1] = BaseModelType.StableDiffusion1
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD2Config(_MainConfig):
"""Model config for ONNX format models based on sd-2."""
type: Literal[ModelType.ONNX] = ModelType.ONNX
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
# No yaml config file for ONNX, so these are part of config
base: Literal[BaseModelType.StableDiffusion2] = BaseModelType.StableDiffusion2
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
upcast_attention: bool = True
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
return model