InvokeAI/invokeai/backend/model_manager/config.py
2024-02-29 13:29:01 -05:00

337 lines
10 KiB
Python

# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.
Typical usage:
from invokeai.backend.model_manager import ModelConfigFactory
raw = dict(path='models/sd-1/main/foo.ckpt',
name='foo',
base='sd-1',
type='main',
config='configs/stable-diffusion/v1-inference.yaml',
variant='normal',
format='checkpoint'
)
config = ModelConfigFactory.make_config(raw)
print(config.name)
Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from ..raw_model import RawModel
# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module]
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
class BaseModelType(str, Enum):
"""Base model type."""
Any = "any"
StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl"
StableDiffusionXLRefiner = "sdxl-refiner"
# Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
"""Model type."""
ONNX = "onnx"
Main = "main"
Vae = "vae"
Lora = "lora"
ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum):
"""Submodel type."""
UNet = "unet"
TextEncoder = "text_encoder"
TextEncoder2 = "text_encoder_2"
Tokenizer = "tokenizer"
Tokenizer2 = "tokenizer_2"
Vae = "vae"
VaeDecoder = "vae_decoder"
VaeEncoder = "vae_encoder"
Scheduler = "scheduler"
SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
"""Variant type."""
Normal = "normal"
Inpaint = "inpaint"
Depth = "depth"
class ModelFormat(str, Enum):
"""Storage format of model."""
Diffusers = "diffusers"
Checkpoint = "checkpoint"
Lycoris = "lycoris"
Onnx = "onnx"
Olive = "olive"
EmbeddingFile = "embedding_file"
EmbeddingFolder = "embedding_folder"
InvokeAI = "invokeai"
class SchedulerPredictionType(str, Enum):
"""Scheduler prediction type."""
Epsilon = "epsilon"
VPrediction = "v_prediction"
Sample = "sample"
class ModelRepoVariant(str, Enum):
"""Various hugging face variants on the diffusers format."""
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
FP16 = "fp16"
FP32 = "fp32"
ONNX = "onnx"
OPENVINO = "openvino"
FLAX = "flax"
class ModelConfigBase(BaseModel):
"""Base class for model configuration information."""
path: str = Field(description="filesystem path to the model file or directory")
name: str = Field(description="model name")
base: BaseModelType = Field(description="base model")
type: ModelType = Field(description="type of the model")
format: ModelFormat = Field(description="model format")
key: str = Field(description="unique key for model", default="<NOKEY>")
original_hash: Optional[str] = Field(
description="original fasthash of model contents", default=None
) # this is assigned at install time and will not change
current_hash: Optional[str] = Field(
description="current fasthash of model contents", default=None
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(description="human readable description of the model", default=None)
source: Optional[str] = Field(description="model original source (path, URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="timestamp for modification time", default_factory=time.time)
@staticmethod
def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
schema["required"].extend(
["key", "base", "type", "format", "original_hash", "current_hash", "source", "last_modified"]
)
model_config = ConfigDict(
use_enum_values=False,
validate_assignment=True,
json_schema_extra=json_schema_extra,
)
def update(self, attributes: Dict[str, Any]) -> None:
"""Update the object with fields in dict."""
for key, value in attributes.items():
setattr(self, key, value) # may raise a validation error
class _CheckpointConfig(ModelConfigBase):
"""Model config for checkpoint-style models."""
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
config: str = Field(description="path to the checkpoint model config file")
class _DiffusersConfig(ModelConfigBase):
"""Model config for diffusers-style models."""
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
class LoRAConfig(ModelConfigBase):
"""Model config for LoRA/Lycoris models."""
type: Literal[ModelType.Lora] = ModelType.Lora
format: Literal[ModelFormat.Lycoris, ModelFormat.Diffusers]
class VaeCheckpointConfig(ModelConfigBase):
"""Model config for standalone VAE models."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class VaeDiffusersConfig(ModelConfigBase):
"""Model config for standalone VAE models (diffusers version)."""
type: Literal[ModelType.Vae] = ModelType.Vae
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetDiffusersConfig(_DiffusersConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
class ControlNetCheckpointConfig(_CheckpointConfig):
"""Model config for ControlNet models (diffusers version)."""
type: Literal[ModelType.ControlNet] = ModelType.ControlNet
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
class TextualInversionConfig(ModelConfigBase):
"""Model config for textual inversion embeddings."""
type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
format: Literal[ModelFormat.EmbeddingFile, ModelFormat.EmbeddingFolder]
class _MainConfig(ModelConfigBase):
"""Model config for main models."""
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False
class MainCheckpointConfig(_CheckpointConfig, _MainConfig):
"""Model config for main checkpoint models."""
type: Literal[ModelType.Main] = ModelType.Main
class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
class IPAdapterConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI]
class CLIPVisionDiffusersConfig(ModelConfigBase):
"""Model config for ClipVision."""
type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
format: Literal[ModelFormat.Diffusers]
class T2IConfig(ModelConfigBase):
"""Model config for T2I."""
type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
format: Literal[ModelFormat.Diffusers]
_ControlNetConfig = Annotated[
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
Field(discriminator="format"),
]
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Union[
_MainModelConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
# below. However, it breaks FastAPI when used as the input Body parameter in a route.
# This is a known issue. Please see:
# https://github.com/tiangolo/fastapi/discussions/9761 and
# https://github.com/tiangolo/fastapi/discussions/9287
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
class ModelConfigFactory(object):
"""Class for parsing config dicts into StableDiffusion Config obects."""
@classmethod
def make_config(
cls,
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type[ModelConfigBase]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
:param model_data: A raw dict corresponding the obect fields to be
parsed into a ModelConfigBase obect (or descendent), or a ModelConfigBase
object, which will be passed through unchanged.
:param dest_class: The config class to be returned. If not provided, will
be selected automatically.
"""
model: Optional[ModelConfigBase] = None
if isinstance(model_data, ModelConfigBase):
model = model_data
elif dest_class:
model = dest_class.model_validate(model_data)
else:
# mypy doesn't typecheck TypeAdapters well?
model = AnyModelConfigValidator.validate_python(model_data) # type: ignore
assert model is not None
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model # type: ignore