mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
525 lines
18 KiB
Python
525 lines
18 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""
|
|
Configuration definitions for image generation models.
|
|
|
|
Typical usage:
|
|
|
|
from invokeai.backend.model_manager import ModelConfigFactory
|
|
raw = dict(path='models/sd-1/main/foo.ckpt',
|
|
name='foo',
|
|
base='sd-1',
|
|
type='main',
|
|
config='configs/stable-diffusion/v1-inference.yaml',
|
|
variant='normal',
|
|
format='checkpoint'
|
|
)
|
|
config = ModelConfigFactory.make_config(raw)
|
|
print(config.name)
|
|
|
|
Validation errors will raise an InvalidModelConfigException error.
|
|
|
|
"""
|
|
|
|
import time
|
|
from enum import Enum
|
|
from typing import Literal, Optional, Type, TypeAlias, Union
|
|
|
|
import diffusers
|
|
import torch
|
|
from diffusers.models.modeling_utils import ModelMixin
|
|
from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapter
|
|
from typing_extensions import Annotated, Any, Dict
|
|
|
|
from invokeai.app.util.misc import uuid_string
|
|
from invokeai.backend.model_hash.hash_validator import validate_hash
|
|
from invokeai.backend.raw_model import RawModel
|
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
|
|
|
# Union of the in-memory model object types handled by the model manager:
#   * ModelMixin     -- base class for all diffusers and transformers models
#   * RawModel       -- InvokeAI wrapper class for ip_adapters, loras,
#                       textual_inversion and onnx runtime
#   * torch.nn.Module, a name -> tensor state dict, or a full diffusers pipeline
AnyModel: TypeAlias = Union[
    ModelMixin,
    RawModel,
    torch.nn.Module,
    Dict[str, torch.Tensor],
    diffusers.DiffusionPipeline,
]
|
|
|
|
|
|
class InvalidModelConfigException(Exception):
    """Raised when the config parser does not recognize a model's combination of type and format."""
|
|
|
|
|
|
class BaseModelType(str, Enum):
    """Base model type."""

    # Members are also ``str`` instances and compare equal to their string
    # value, e.g. BaseModelType.StableDiffusionXL == "sdxl".
    Any = "any"
    # Stable Diffusion families
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
    # Kandinsky2_1 = "kandinsky-2.1"
|
|
|
|
|
|
class ModelType(str, Enum):
    """Model type."""

    # The string value is the first half of every config's "<type>.<format>" tag.
    ONNX = "onnx"
    Main = "main"
    VAE = "vae"
    LoRA = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"
|
|
|
|
|
|
class SubModelType(str, Enum):
    """Submodel type."""

    # Denoiser components
    UNet = "unet"
    Transformer = "transformer"
    # Text conditioning components (primary and secondary)
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    # Autoencoder, whole or split into halves
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
    # Remaining components
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
|
|
|
|
|
|
class ModelVariantType(str, Enum):
    """Variant type."""

    # MainConfigBase defaults to Normal.
    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"
|
|
|
|
|
|
class ModelFormat(str, Enum):
    """Storage format of model."""

    # The string value is the second half of every config's "<type>.<format>" tag.
    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    LyCORIS = "lycoris"
    ONNX = "onnx"
    Olive = "olive"
    # Textual-inversion embeddings stored as a single file or as a folder
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
    T5Encoder = "t5_encoder"
    # Quantized formats (note the differing capitalization of the member names)
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
|
|
|
|
|
|
class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    # MainCheckpointConfig defaults to Epsilon.
    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"
|
|
|
|
|
|
class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    # The empty string means the plain files, with no "fp16" (or other) qualifier.
    Default = ""
    # Precision variants
    FP16 = "fp16"
    FP32 = "fp32"
    # Runtime / framework variants
    ONNX = "onnx"
    OpenVINO = "openvino"
    Flax = "flax"
|
|
|
|
|
|
class ModelSourceType(str, Enum):
    """Model source type."""

    # Describes what ModelConfigBase.source holds: a path, a URL or a repo id.
    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"
|
|
|
|
|
|
DEFAULTS_PRECISION = Literal["fp16", "fp32"]
|
|
|
|
|
|
# Optional per-model generation defaults for main models. Every field may be left
# unset (None); unknown keys are rejected by extra="forbid".
class MainModelDefaultSettings(BaseModel):
    vae: Optional[str] = Field(default=None, description="Default VAE for this model (model key)")
    vae_precision: Optional[DEFAULTS_PRECISION] = Field(
        default=None, description="Default VAE precision for this model"
    )
    scheduler: Optional[SCHEDULER_NAME_VALUES] = Field(
        default=None, description="Default scheduler for this model"
    )
    steps: Optional[int] = Field(default=None, description="Default number of steps for this model", gt=0)
    cfg_scale: Optional[float] = Field(default=None, description="Default CFG Scale for this model", ge=1)
    # Constrained to the half-open interval [0, 1).
    cfg_rescale_multiplier: Optional[float] = Field(
        default=None, description="Default CFG Rescale Multiplier for this model", ge=0, lt=1
    )
    # Dimensions must be multiples of 8 and at least 64.
    width: Optional[int] = Field(
        default=None, description="Default width for this model", ge=64, multiple_of=8
    )
    height: Optional[int] = Field(
        default=None, description="Default height for this model", ge=64, multiple_of=8
    )

    model_config = ConfigDict(extra="forbid")
|
|
|
|
|
|
# Per-model defaults for control adapters; unknown keys are rejected by extra="forbid".
class ControlAdapterDefaultSettings(BaseModel):
    # Kept as a free-form string rather than narrowed to the controlnet processor
    # nodes, because that set of nodes changes. The field has no default, so the
    # key must be supplied, although its value may be None.
    preprocessor: Optional[str]

    model_config = ConfigDict(extra="forbid")
|
|
|
|
|
|
class ModelConfigBase(BaseModel):
    """Base class for model configuration information."""

    # Fields shared by every model config. Concrete subclasses add the ``type`` and
    # ``format`` fields whose values select the config class (see get_tag()).
    key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
    hash: str = Field(description="The hash of the model file(s).")
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
    )
    name: str = Field(description="Name of the model.")
    base: BaseModelType = Field(description="The base model.")
    description: Optional[str] = Field(description="Model description", default=None)
    source: str = Field(description="The original source of the model (path, URL or repo_id).")
    source_type: ModelSourceType = Field(description="The type of source")
    source_api_response: Optional[str] = Field(
        description="The original API response from the source, as stringified JSON.", default=None
    )
    cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
        """Mark ``key``, ``type`` and ``format`` as required in the generated JSON schema.

        Pydantic lists fields that have defaults (``key`` has a default factory, and most
        subclasses default ``type``/``format``) as optional; this hook adds them to
        ``required``. Names that are already required (e.g. a ``format`` declared without
        a default) are not appended again, because JSON Schema requires the entries of
        ``required`` to be unique.

        :param schema: The JSON schema dict generated for ``model_class``; modified in place.
        :param model_class: The model class the schema was generated for (unused).
        """
        required = schema.setdefault("required", [])
        for field_name in ("key", "type", "format"):
            if field_name not in required:
                required.append(field_name)

    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
|
|
|
|
|
|
class CheckpointConfigBase(ModelConfigBase):
    """Model config for checkpoint-style models."""

    # Plain checkpoints use ModelFormat.Checkpoint; the NF4-quantized subclass uses the other value.
    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
        default=ModelFormat.Checkpoint, description="Format of the provided checkpoint model"
    )
    config_path: str = Field(description="path to the checkpoint model config file")
    # Defaults to time.time() evaluated when the config object is created.
    converted_at: float | None = Field(
        default_factory=time.time, description="When this model was last converted to diffusers"
    )
|
|
|
|
|
|
class DiffusersConfigBase(ModelConfigBase):
    """Model config for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)
    # Which Hugging Face file variant the model uses; Default ("") means no qualifier.
    repo_variant: ModelRepoVariant | None = Field(default=ModelRepoVariant.Default)
|
|
|
|
|
|
# Fields shared by all LoRA configs; subclasses pin the storage format.
class LoRAConfigBase(ModelConfigBase):
    type: Literal[ModelType.LoRA] = Field(default=ModelType.LoRA)
    trigger_phrases: set[str] | None = Field(default=None, description="Set of trigger phrases for this model")
|
|
|
|
|
|
# Shared base for T5 encoder configs; subclasses pin the storage format.
class T5EncoderConfigBase(ModelConfigBase):
    type: Literal[ModelType.T5Encoder] = Field(default=ModelType.T5Encoder)
|
|
|
|
|
|
# T5 encoder stored in ModelFormat.T5Encoder format; tag "t5_encoder.t5_encoder".
class T5EncoderConfig(T5EncoderConfigBase):
    format: Literal[ModelFormat.T5Encoder] = Field(default=ModelFormat.T5Encoder)

    @staticmethod
    def get_tag() -> Tag:
        return Tag(".".join((ModelType.T5Encoder.value, ModelFormat.T5Encoder.value)))
|
|
|
|
|
|
# T5 encoder stored in the "bnb_quantized_int8b" format; tag "t5_encoder.bnb_quantized_int8b".
class T5EncoderBnbQuantizedLlmInt8bConfig(T5EncoderConfigBase):
    format: Literal[ModelFormat.BnbQuantizedLlmInt8b] = Field(default=ModelFormat.BnbQuantizedLlmInt8b)

    @staticmethod
    def get_tag() -> Tag:
        return Tag(".".join((ModelType.T5Encoder.value, ModelFormat.BnbQuantizedLlmInt8b.value)))
|
|
|
|
|
|
class LoRALyCORISConfig(LoRAConfigBase):
    """Model config for LoRA/Lycoris models."""

    format: Literal[ModelFormat.LyCORIS] = Field(default=ModelFormat.LyCORIS)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "lora.lycoris"
        return Tag(".".join((ModelType.LoRA.value, ModelFormat.LyCORIS.value)))
|
|
|
|
|
|
class LoRADiffusersConfig(LoRAConfigBase):
    """Model config for LoRA/Diffusers models."""

    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "lora.diffusers"
        return Tag(".".join((ModelType.LoRA.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
class VAECheckpointConfig(CheckpointConfigBase):
    """Model config for standalone VAE models."""

    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "vae.checkpoint"
        return Tag(".".join((ModelType.VAE.value, ModelFormat.Checkpoint.value)))
|
|
|
|
|
|
class VAEDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    # NOTE(review): derives from ModelConfigBase rather than DiffusersConfigBase,
    # so it has no repo_variant field -- confirm this is intentional.
    type: Literal[ModelType.VAE] = Field(default=ModelType.VAE)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "vae.diffusers"
        return Tag(".".join((ModelType.VAE.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
# Mixin (a plain BaseModel, not a ModelConfigBase) adding optional default settings
# to the ControlNet and T2I-Adapter configs.
class ControlAdapterConfigBase(BaseModel):
    default_settings: ControlAdapterDefaultSettings | None = Field(
        default=None, description="Default settings for this model"
    )
|
|
|
|
|
|
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = Field(default=ModelType.ControlNet)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "controlnet.diffusers"
        return Tag(".".join((ModelType.ControlNet.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (checkpoint version)."""

    # The format field (default ModelFormat.Checkpoint) is inherited from CheckpointConfigBase.
    type: Literal[ModelType.ControlNet] = ModelType.ControlNet

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "controlnet.checkpoint"
        return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
|
|
|
|
|
|
class TextualInversionFileConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    # Single-file variant (ModelFormat.EmbeddingFile); compare TextualInversionFolderConfig.
    type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
    format: Literal[ModelFormat.EmbeddingFile] = Field(default=ModelFormat.EmbeddingFile)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "embedding.embedding_file"
        return Tag(".".join((ModelType.TextualInversion.value, ModelFormat.EmbeddingFile.value)))
|
|
|
|
|
|
class TextualInversionFolderConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    # Folder variant (ModelFormat.EmbeddingFolder); compare TextualInversionFileConfig.
    type: Literal[ModelType.TextualInversion] = Field(default=ModelType.TextualInversion)
    format: Literal[ModelFormat.EmbeddingFolder] = Field(default=ModelFormat.EmbeddingFolder)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "embedding.embedding_folder"
        return Tag(".".join((ModelType.TextualInversion.value, ModelFormat.EmbeddingFolder.value)))
|
|
|
|
|
|
# Fields shared by every main-model config, whatever its storage format.
class MainConfigBase(ModelConfigBase):
    type: Literal[ModelType.Main] = Field(default=ModelType.Main)
    trigger_phrases: set[str] | None = Field(default=None, description="Set of trigger phrases for this model")
    default_settings: MainModelDefaultSettings | None = Field(
        default=None, description="Default settings for this model"
    )
    variant: ModelVariantType = Field(default=ModelVariantType.Normal)
|
|
|
|
|
|
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models."""

    prediction_type: SchedulerPredictionType = Field(default=SchedulerPredictionType.Epsilon)
    upcast_attention: bool = Field(default=False)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.checkpoint"
        return Tag(".".join((ModelType.Main.value, ModelFormat.Checkpoint.value)))
|
|
|
|
|
|
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models stored in the bnb_quantized_nf4b format."""

    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # CheckpointConfigBase defaults format to ModelFormat.Checkpoint. Pin it to the
        # quantized format regardless of the input. Because validate_assignment=True,
        # this assignment is itself validated against CheckpointConfigBase's Literal,
        # which includes BnbQuantizednf4b.
        # NOTE(review): relies on pydantic invoking this custom __init__ during
        # model_validate / TypeAdapter validation as well -- confirm for the pinned
        # pydantic version.
        self.format = ModelFormat.BnbQuantizednf4b

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.bnb_quantized_nf4b"
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
|
|
|
|
|
|
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""

    # All fields are inherited; this class only contributes its discriminator tag.
    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.diffusers"
        return Tag(".".join((ModelType.Main.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
# Shared base for IP Adapter configs. Subclasses declare their own (required) format field.
class IPAdapterBaseConfig(ModelConfigBase):
    type: Literal[ModelType.IPAdapter] = Field(default=ModelType.IPAdapter)
|
|
|
|
|
|
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
    """Model config for IP Adapter InvokeAI format models."""

    # Identifier of the image encoder to pair with this adapter.
    # NOTE(review): the identifier's form (model key vs. repo id) is not visible here -- confirm.
    image_encoder_model_id: str
    # No default: the format must be supplied explicitly.
    format: Literal[ModelFormat.InvokeAI]

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "ip_adapter.invokeai"
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
|
|
|
|
|
|
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
    """Model config for IP Adapter checkpoint format models."""

    # Declared without a default, so the format must be supplied explicitly.
    format: Literal[ModelFormat.Checkpoint]

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "ip_adapter.checkpoint"
        return Tag(".".join((ModelType.IPAdapter.value, ModelFormat.Checkpoint.value)))
|
|
|
|
|
|
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    type: Literal[ModelType.CLIPEmbed] = Field(default=ModelType.CLIPEmbed)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "clip_embed.diffusers"
        return Tag(".".join((ModelType.CLIPEmbed.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""

    type: Literal[ModelType.CLIPVision] = Field(default=ModelType.CLIPVision)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "clip_vision.diffusers"
        return Tag(".".join((ModelType.CLIPVision.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = Field(default=ModelType.T2IAdapter)
    format: Literal[ModelFormat.Diffusers] = Field(default=ModelFormat.Diffusers)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "t2i_adapter.diffusers"
        return Tag(".".join((ModelType.T2IAdapter.value, ModelFormat.Diffusers.value)))
|
|
|
|
|
|
class SpandrelImageToImageConfig(ModelConfigBase):
    """Model config for Spandrel Image to Image models."""

    type: Literal[ModelType.SpandrelImageToImage] = Field(default=ModelType.SpandrelImageToImage)
    format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "spandrel_image_to_image.checkpoint"
        return Tag(".".join((ModelType.SpandrelImageToImage.value, ModelFormat.Checkpoint.value)))
|
|
|
|
|
|
def get_model_discriminator_value(v: Any) -> str:
    """
    Computes the discriminator value for a model config.

    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator

    The result is ``"<type>.<format>"``, which matches the ``get_tag()`` string of the
    config class that should handle ``v``.

    :param v: Either a raw dict of field values, or an already-constructed object
        exposing ``type`` and ``format`` attributes.
    :return: The tag string. Enum members contribute their ``.value``, and plain strings
        are used as-is. A missing key or attribute is rendered as ``None``, which matches
        no tag, so pydantic reports a union-tag validation error instead of this function
        crashing.
    """

    def _raw(value: Any) -> Any:
        # Normalize enum members (e.g. ModelType.Main) to their underlying value.
        return value.value if isinstance(value, Enum) else value

    if isinstance(v, dict):
        type_ = v.get("type")
        format_ = v.get("format")
    else:
        # Read attributes tolerantly: the previous unconditional `.format.value` raised
        # AttributeError for plain-string or missing attributes.
        type_ = getattr(v, "type", None)
        format_ = getattr(v, "format", None)

    return f"{_raw(type_)}.{_raw(format_)}"
|
|
|
|
|
|
# Discriminated union of every concrete model config class. Each member is tagged with
# its get_tag() string ("<type>.<format>"). get_model_discriminator_value() computes the
# same string from incoming data, and pydantic validates against the member whose tag
# matches. All tags below are distinct.
# NOTE(review): the member order presumably also fixes the order of the generated JSON
# schema / OpenAPI union -- keep it stable unless that is intended.
AnyModelConfig = Annotated[
    Union[
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
        Annotated[T5EncoderBnbQuantizedLlmInt8bConfig, T5EncoderBnbQuantizedLlmInt8bConfig.get_tag()],
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]

# Module-level validator for AnyModelConfig, used by ModelConfigFactory.make_config()
# when no explicit dest_class is given.
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
# Either kind of per-model default settings.
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
|
|
|
|
|
|
class ModelConfigFactory:
    """Class for parsing config dicts into model config objects."""

    @classmethod
    def make_config(
        cls,
        model_data: Union[Dict[str, Any], AnyModelConfig],
        key: Optional[str] = None,
        dest_class: Optional[Type[ModelConfigBase]] = None,
        timestamp: Optional[float] = None,
    ) -> AnyModelConfig:
        """
        Return the appropriate config object from raw dict values.

        :param model_data: Either a raw dict of field values to be validated into a
            ModelConfigBase subclass, or an existing ModelConfigBase object. An existing
            object is NOT copied: the same instance is returned, and ``key``/``timestamp``
            (if given) are applied to it in place.
        :param key: If truthy, assigned to the resulting config's ``key``, replacing the
            generated or supplied one.
        :param dest_class: The config class to validate into. If not provided, the class is
            selected automatically from the ``type`` and ``format`` values.
        :param timestamp: If not None and the result is a CheckpointConfigBase, stored as
            its ``converted_at`` value.
        :return: The validated config object.
        :raises pydantic.ValidationError: If ``model_data`` does not validate (including when
            no config class matches its type/format). ``validate_hash`` is also called on
            the result; whatever it raises for a bad hash propagates unchanged.
        """
        model: Optional[ModelConfigBase] = None
        if isinstance(model_data, ModelConfigBase):
            model = model_data
        elif dest_class:
            model = dest_class.model_validate(model_data)
        else:
            # mypy doesn't typecheck TypeAdapters well?
            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore
        assert model is not None  # narrows the Optional for type checkers

        # The assignments below go through pydantic's validate_assignment
        # (enabled in ModelConfigBase.model_config).
        if key:
            model.key = key
        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
            model.converted_at = timestamp

        # `model` is always a pydantic model here. The former `if model:` guard was a
        # no-op, so the hash is validated unconditionally.
        validate_hash(model.hash)
        return model  # type: ignore
|