2023-11-05 03:03:26 +00:00
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""
Configuration definitions for image generation models.

Typical usage:

    from invokeai.backend.model_manager import ModelConfigFactory
    raw = dict(path='models/sd-1/main/foo.ckpt',
               name='foo',
               base='sd-1',
               type='main',
               config_path='configs/stable-diffusion/v1-inference.yaml',
               variant='normal',
               format='checkpoint'
               )
    config = ModelConfigFactory.make_config(raw)
    print(config.name)

In practice ``raw`` must also supply the other required fields (for example
``hash``, ``source`` and ``source_type``); they are omitted above for brevity.

Invalid field values are rejected by pydantic validation (``pydantic.ValidationError``).
"""
2024-02-29 23:04:59 +00:00
2024-02-04 03:55:09 +00:00
import time
2023-11-05 03:03:26 +00:00
from enum import Enum
2024-03-08 04:37:31 +00:00
from typing import Literal , Optional , Type , TypeAlias , Union
2023-11-05 03:03:26 +00:00
2024-06-27 21:31:28 +00:00
import diffusers
2024-02-04 22:23:10 +00:00
import torch
2024-03-01 02:05:16 +00:00
from diffusers . models . modeling_utils import ModelMixin
2024-03-04 08:17:01 +00:00
from pydantic import BaseModel , ConfigDict , Discriminator , Field , Tag , TypeAdapter
2023-11-24 04:15:32 +00:00
from typing_extensions import Annotated , Any , Dict
2024-02-04 22:23:10 +00:00
2024-03-04 10:38:21 +00:00
from invokeai . app . util . misc import uuid_string
2024-06-13 20:34:27 +00:00
from invokeai . backend . model_hash . hash_validator import validate_hash
2024-07-03 16:04:22 +00:00
from invokeai . backend . raw_model import RawModel
2024-07-03 15:13:16 +00:00
from invokeai . backend . stable_diffusion . schedulers . schedulers import SCHEDULER_NAME_VALUES
2024-02-05 04:18:00 +00:00
2024-02-17 16:45:32 +00:00
# AnyModel: union of the in-memory model object types referenced here.
#   - ModelMixin is the base class for all diffusers and transformers models.
#   - RawModel is the InvokeAI wrapper class for ip_adapters, loras,
#     textual_inversion and onnx runtime.
#   - torch.nn.Module and diffusers.DiffusionPipeline cover plain torch modules
#     and full pipelines.
#   - Dict[str, torch.Tensor] is a bare mapping of names to tensors.
AnyModel = Union[
    ModelMixin,
    RawModel,
    torch.nn.Module,
    Dict[str, torch.Tensor],
    diffusers.DiffusionPipeline,
]
2024-02-06 03:56:32 +00:00
2024-02-05 04:18:00 +00:00
2023-11-05 03:03:26 +00:00
class InvalidModelConfigException(Exception):
    """Raised when the config parser doesn't recognize a combination of model type and format."""
class BaseModelType(str, Enum):
    """Base model type."""

    # Wildcard: the config is not tied to a particular base architecture.
    Any = "any"
    # Stable Diffusion family.
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    # Flux family.
    Flux = "flux"
    # Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum):
    """Model type."""

    # Core generation models.
    ONNX = "onnx"
    Main = "main"
    VAE = "vae"
    # Add-on / adapter models.
    LoRA = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    # Encoders.
    CLIPVision = "clip_vision"
    CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    # Image-to-image (upscaling etc.) models loaded through spandrel.
    SpandrelImageToImage = "spandrel_image_to_image"
2023-11-05 03:03:26 +00:00
class SubModelType(str, Enum):
    """Submodel type."""

    # Denoisers.
    UNet = "unet"
    Transformer = "transformer"
    # Text conditioning.
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    # Autoencoder and its halves.
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
    # Pipeline auxiliaries.
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
class ModelVariantType(str, Enum):
    """Variant type."""

    # Plain text-to-image model (MainConfigBase's default variant).
    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"
class ModelFormat(str, Enum):
    """Storage format of model."""

    # General-purpose layouts.
    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    LyCORIS = "lycoris"
    ONNX = "onnx"
    Olive = "olive"
    # Textual-inversion embeddings: single file vs. folder.
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
    # T5 encoder formats (full / 8-bit / 4-bit variants by name).
    T5Encoder = "t5_encoder"
    T5Encoder8b = "t5_encoder_8b"
    T5Encoder4b = "t5_encoder_4b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
2023-11-05 03:03:26 +00:00
class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    # Values mirror the strings used for scheduler prediction types.
    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"
2024-01-14 19:54:53 +00:00
class ModelRepoVariant(str, Enum):
    """Various hugging face variants on the diffusers format."""

    # Empty string: model files without "fp16" or other qualifier.
    Default = ""
    # Precision-qualified variants.
    FP16 = "fp16"
    FP32 = "fp32"
    # Runtime/framework-specific variants.
    ONNX = "onnx"
    OpenVINO = "openvino"
    Flax = "flax"
2024-01-14 19:54:53 +00:00
2024-03-01 11:12:48 +00:00
class ModelSourceType(str, Enum):
    """Model source type."""

    # Where ModelConfigBase.source came from.
    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"
2024-03-12 09:07:53 +00:00
DEFAULTS_PRECISION = Literal [ " fp16 " , " fp32 " ]
2024-03-08 04:32:02 +00:00
class MainModelDefaultSettings(BaseModel):
    # Optional per-model default generation settings for main models.
    # Every field may be left as None; bounds are enforced by pydantic.
    vae: str | None = Field(description="Default VAE for this model (model key)", default=None)
    vae_precision: DEFAULTS_PRECISION | None = Field(
        description="Default VAE precision for this model", default=None
    )
    scheduler: SCHEDULER_NAME_VALUES | None = Field(description="Default scheduler for this model", default=None)
    steps: int | None = Field(description="Default number of steps for this model", default=None, gt=0)
    cfg_scale: float | None = Field(description="Default CFG Scale for this model", default=None, ge=1)
    # Half-open range [0, 1).
    cfg_rescale_multiplier: float | None = Field(
        description="Default CFG Rescale Multiplier for this model", default=None, ge=0, lt=1
    )
    # Dimensions must be multiples of 8 and at least 64.
    width: int | None = Field(description="Default width for this model", default=None, ge=64, multiple_of=8)
    height: int | None = Field(description="Default height for this model", default=None, ge=64, multiple_of=8)

    # Unknown keys are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")
2024-03-05 00:24:25 +00:00
2024-03-08 04:33:23 +00:00
class ControlAdapterDefaultSettings(BaseModel):
    # Default settings attached to control-adapter configs (via ControlAdapterConfigBase).
    #
    # This could be narrowed to controlnet processor nodes, but they change. Leaving this a string is safer.
    # No default is given, so pydantic requires the key to be present (its value may be None).
    preprocessor: str | None

    # Unknown keys are rejected instead of silently ignored.
    model_config = ConfigDict(extra="forbid")
2024-03-08 04:33:23 +00:00
2023-11-05 03:03:26 +00:00
class ModelConfigBase(BaseModel):
    """Base class for model configuration information."""

    # `key` gets a fresh uuid string when not supplied.
    key: str = Field(description="A unique key for this model.", default_factory=uuid_string)
    hash: str = Field(description="The hash of the model file(s).")
    path: str = Field(
        description="Path to the model on the filesystem. Relative paths are relative to the Invoke root directory."
    )
    name: str = Field(description="Name of the model.")
    base: BaseModelType = Field(description="The base model.")
    description: Optional[str] = Field(description="Model description", default=None)
    source: str = Field(description="The original source of the model (path, URL or repo_id).")
    source_type: ModelSourceType = Field(description="The type of source")
    source_api_response: Optional[str] = Field(
        description="The original API response from the source, as stringified JSON.", default=None
    )
    cover_image: Optional[str] = Field(description="Url for image to preview model", default=None)

    @staticmethod
    def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None:
        """Ensure ``key``, ``type`` and ``format`` are listed as required in the JSON schema.

        These fields carry defaults (``key`` via ``default_factory``; ``type``/``format`` in
        subclasses), so pydantic would otherwise omit them from ``required``.

        Names already present are not appended again: some subclasses declare ``format``
        without a default, and JSON Schema requires ``required`` entries to be unique.

        :param schema: The generated JSON schema, modified in place.
        :param model_class: The model class the schema was generated for (unused).
        """
        required: list[str] = schema.setdefault("required", [])
        for field_name in ("key", "type", "format"):
            if field_name not in required:
                required.append(field_name)

    # validate_assignment: attribute writes (e.g. model.key = ...) are re-validated.
    model_config = ConfigDict(validate_assignment=True, json_schema_extra=json_schema_extra)
2023-11-05 03:03:26 +00:00
2024-03-01 02:18:31 +00:00
class CheckpointConfigBase(ModelConfigBase):
    """Model config for checkpoint-style models."""

    # Defaults to plain "checkpoint"; the bnb NF4 format is also accepted here.
    format: Literal[ModelFormat.Checkpoint, ModelFormat.BnbQuantizednf4b] = Field(
        default=ModelFormat.Checkpoint,
        description="Format of the provided checkpoint model",
    )
    config_path: str = Field(description="path to the checkpoint model config file")
    # Filled with the current time (time.time) when not supplied.
    converted_at: float | None = Field(
        default_factory=time.time,
        description="When this model was last converted to diffusers",
    )
2023-11-05 03:03:26 +00:00
2024-03-01 02:18:31 +00:00
class DiffusersConfigBase(ModelConfigBase):
    """Model config for diffusers-style models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
    # Defaults to the unqualified variant (files without "fp16" etc.).
    repo_variant: ModelRepoVariant | None = ModelRepoVariant.Default
2023-11-05 03:03:26 +00:00
2024-02-01 04:37:59 +00:00
2024-03-07 04:36:18 +00:00
class LoRAConfigBase(ModelConfigBase):
    # Fields shared by every LoRA config, whatever its storage format.
    type: Literal[ModelType.LoRA] = ModelType.LoRA
    trigger_phrases: set[str] | None = Field(default=None, description="Set of trigger phrases for this model")
2024-08-16 21:04:48 +00:00
class T5EncoderConfigBase(ModelConfigBase):
    # Fields shared by every standalone T5 encoder config; subclasses pin `format`.
    type: Literal[ModelType.T5Encoder] = ModelType.T5Encoder
class T5EncoderConfig(T5EncoderConfigBase):
    # T5 encoder stored in the plain "t5_encoder" format.
    format: Literal[ModelFormat.T5Encoder] = ModelFormat.T5Encoder

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "<type>.<format>".
        return Tag(".".join((ModelType.T5Encoder.value, ModelFormat.T5Encoder.value)))
2024-08-20 16:37:12 +00:00
class T5Encoder8bConfig(T5EncoderConfigBase):
    # T5 encoder stored in the "t5_encoder_8b" format.
    format: Literal[ModelFormat.T5Encoder8b] = ModelFormat.T5Encoder8b

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "<type>.<format>".
        type_part = ModelType.T5Encoder.value
        format_part = ModelFormat.T5Encoder8b.value
        return Tag(f"{type_part}.{format_part}")
2024-03-07 04:36:18 +00:00
class LoRALyCORISConfig(LoRAConfigBase):
    """Model config for LoRA/Lycoris models."""

    format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "lora.lycoris".
        return Tag(".".join((ModelType.LoRA.value, ModelFormat.LyCORIS.value)))
2024-03-01 01:57:46 +00:00
2024-03-07 04:36:18 +00:00
class LoRADiffusersConfig(LoRAConfigBase):
    """Model config for LoRA/Diffusers models."""

    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "lora.diffusers".
        return Tag(".".join((ModelType.LoRA.value, ModelFormat.Diffusers.value)))
2023-11-05 03:03:26 +00:00
2024-03-05 06:37:17 +00:00
class VAECheckpointConfig(CheckpointConfigBase):
    """Model config for standalone VAE models."""

    # `format` (checkpoint by default) comes from CheckpointConfigBase.
    type: Literal[ModelType.VAE] = ModelType.VAE

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "vae.checkpoint".
        return Tag(".".join((ModelType.VAE.value, ModelFormat.Checkpoint.value)))
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-03-05 06:37:17 +00:00
class VAEDiffusersConfig(ModelConfigBase):
    """Model config for standalone VAE models (diffusers version)."""

    # Derives from ModelConfigBase directly, so there is no repo_variant field here.
    type: Literal[ModelType.VAE] = ModelType.VAE
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "vae.diffusers".
        return Tag(".".join((ModelType.VAE.value, ModelFormat.Diffusers.value)))
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-03-08 04:33:23 +00:00
class ControlAdapterConfigBase(BaseModel):
    # Mixin that adds optional `default_settings` to control-adapter configs
    # (ControlNet and T2I-Adapter classes inherit it alongside a ModelConfigBase subclass).
    default_settings: ControlAdapterDefaultSettings | None = Field(
        default=None,
        description="Default settings for this model",
    )
class ControlNetDiffusersConfig(DiffusersConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (diffusers version)."""

    type: Literal[ModelType.ControlNet] = ModelType.ControlNet
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "controlnet.diffusers".
        return Tag(".".join((ModelType.ControlNet.value, ModelFormat.Diffusers.value)))
2024-03-01 01:57:46 +00:00
2024-02-01 04:37:59 +00:00
2024-03-08 04:33:23 +00:00
class ControlNetCheckpointConfig(CheckpointConfigBase, ControlAdapterConfigBase):
    """Model config for ControlNet models (checkpoint version)."""

    # `format` (checkpoint by default) comes from CheckpointConfigBase.
    type: Literal[ModelType.ControlNet] = ModelType.ControlNet

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "controlnet.checkpoint".
        return Tag(f"{ModelType.ControlNet.value}.{ModelFormat.Checkpoint.value}")
2024-03-01 01:57:46 +00:00
class TextualInversionFileConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    # Single-file embedding.
    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFile] = ModelFormat.EmbeddingFile

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "embedding.embedding_file".
        return Tag(".".join((ModelType.TextualInversion.value, ModelFormat.EmbeddingFile.value)))
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-03-01 01:57:46 +00:00
class TextualInversionFolderConfig(ModelConfigBase):
    """Model config for textual inversion embeddings."""

    # Folder-based embedding.
    type: Literal[ModelType.TextualInversion] = ModelType.TextualInversion
    format: Literal[ModelFormat.EmbeddingFolder] = ModelFormat.EmbeddingFolder

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "embedding.embedding_folder".
        return Tag(".".join((ModelType.TextualInversion.value, ModelFormat.EmbeddingFolder.value)))
2023-11-05 03:03:26 +00:00
2024-03-07 04:36:18 +00:00
class MainConfigBase(ModelConfigBase):
    # Fields shared by every main-model config, whatever its storage format.
    type: Literal[ModelType.Main] = ModelType.Main
    trigger_phrases: set[str] | None = Field(default=None, description="Set of trigger phrases for this model")
    default_settings: MainModelDefaultSettings | None = Field(
        default=None,
        description="Default settings for this model",
    )
    variant: ModelVariantType = ModelVariantType.Normal
2024-03-07 04:36:18 +00:00
class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models."""

    # Checkpoint-specific knobs; defaults: epsilon prediction, no attention upcast.
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.checkpoint".
        return Tag(".".join((ModelType.Main.value, ModelFormat.Checkpoint.value)))
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-08-19 16:08:24 +00:00
class MainBnbQuantized4bCheckpointConfig(CheckpointConfigBase, MainConfigBase):
    """Model config for main checkpoint models in the bitsandbytes NF4-quantized (``bnb_quantized_nf4b``) format."""

    # Declared as a narrowed Literal field instead of being assigned in __init__:
    # the default then also applies to model_construct() and appears in the JSON
    # schema, so get_model_discriminator_value() always yields "main.bnb_quantized_nf4b".
    format: Literal[ModelFormat.BnbQuantizednf4b] = ModelFormat.BnbQuantizednf4b
    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
    upcast_attention: bool = False

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.bnb_quantized_nf4b".
        return Tag(f"{ModelType.Main.value}.{ModelFormat.BnbQuantizednf4b.value}")
2024-03-07 04:36:18 +00:00
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
    """Model config for main diffusers models."""

    # All fields are inherited; `format` is diffusers via DiffusersConfigBase.

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "main.diffusers".
        return Tag(".".join((ModelType.Main.value, ModelFormat.Diffusers.value)))
2024-03-01 01:57:46 +00:00
2024-02-01 04:37:59 +00:00
2024-03-23 20:10:28 +00:00
class IPAdapterBaseConfig(ModelConfigBase):
    # Fields shared by every IP Adapter config; subclasses declare `format`.
    type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
2024-03-23 20:10:28 +00:00
2024-03-29 06:20:18 +00:00
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
    """Model config for IP Adapter models in the InvokeAI format."""

    # NOTE(review): presumably identifies the image encoder model this adapter needs — confirm.
    image_encoder_model_id: str
    # No default: the input must state the format explicitly.
    format: Literal[ModelFormat.InvokeAI]

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "ip_adapter.invokeai".
        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-03-23 20:10:28 +00:00
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
    """Model config for IP Adapter checkpoint format models."""

    # No default: the input must state the format explicitly.
    format: Literal[ModelFormat.Checkpoint]

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "ip_adapter.checkpoint".
        return Tag(".".join((ModelType.IPAdapter.value, ModelFormat.Checkpoint.value)))
2024-08-16 21:04:48 +00:00
class CLIPEmbedDiffusersConfig(DiffusersConfigBase):
    """Model config for Clip Embeddings."""

    type: Literal[ModelType.CLIPEmbed] = ModelType.CLIPEmbed
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "clip_embed.diffusers".
        return Tag(".".join((ModelType.CLIPEmbed.value, ModelFormat.Diffusers.value)))
2024-03-19 20:14:12 +00:00
class CLIPVisionDiffusersConfig(DiffusersConfigBase):
    """Model config for CLIPVision."""

    type: Literal[ModelType.CLIPVision] = ModelType.CLIPVision
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "clip_vision.diffusers".
        return Tag(".".join((ModelType.CLIPVision.value, ModelFormat.Diffusers.value)))
2024-03-01 01:57:46 +00:00
2023-11-05 03:03:26 +00:00
2024-03-19 20:14:12 +00:00
class T2IAdapterConfig(DiffusersConfigBase, ControlAdapterConfigBase):
    """Model config for T2I."""

    type: Literal[ModelType.T2IAdapter] = ModelType.T2IAdapter
    format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "t2i_adapter.diffusers".
        return Tag(".".join((ModelType.T2IAdapter.value, ModelFormat.Diffusers.value)))
2024-03-01 01:57:46 +00:00
2024-06-28 22:03:09 +00:00
class SpandrelImageToImageConfig(ModelConfigBase):
    """Model config for Spandrel Image to Image models."""

    type: Literal[ModelType.SpandrelImageToImage] = ModelType.SpandrelImageToImage
    format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint

    @staticmethod
    def get_tag() -> Tag:
        # Discriminator tag: "spandrel_image_to_image.checkpoint".
        return Tag(".".join((ModelType.SpandrelImageToImage.value, ModelFormat.Checkpoint.value)))
2024-03-01 01:57:46 +00:00
def get_model_discriminator_value(v: Any) -> str:
    """
    Build the ``"<type>.<format>"`` tag that selects a member of ``AnyModelConfig``.

    ``v`` is either raw input data (a ``dict`` whose ``type``/``format`` values may be
    enum members or plain strings; a missing key renders as ``"None"``) or an
    already-built config object whose ``type``/``format`` attributes are enum members.

    https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions-with-callable-discriminator
    """
    if not isinstance(v, dict):
        # Config instance: read the enum values directly (format first, then type).
        fmt = v.format.value
        kind = v.type.value
        return f"{kind}.{fmt}"

    def _plain(raw: Any) -> Any:
        # Unwrap enum members to their underlying value; leave anything else alone.
        return raw.value if isinstance(raw, Enum) else raw

    fmt = _plain(v.get("format"))
    kind = _plain(v.get("type"))
    return f"{kind}.{fmt}"
2024-03-01 01:57:46 +00:00
# Discriminated union of every concrete config class. Pydantic calls
# get_model_discriminator_value() on the input and picks the member whose
# Tag equals the resulting "<type>.<format>" string.
AnyModelConfig = Annotated[
    Union[
        # Main models
        Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
        Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
        Annotated[MainBnbQuantized4bCheckpointConfig, MainBnbQuantized4bCheckpointConfig.get_tag()],
        # VAEs
        Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
        Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
        # ControlNets
        Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
        Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
        # LoRAs
        Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
        Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
        # T5 encoders
        Annotated[T5EncoderConfig, T5EncoderConfig.get_tag()],
        Annotated[T5Encoder8bConfig, T5Encoder8bConfig.get_tag()],
        # Textual inversion embeddings
        Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
        Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
        # IP adapters
        Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
        # Other adapters, image-to-image and CLIP models
        Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
        Annotated[SpandrelImageToImageConfig, SpandrelImageToImageConfig.get_tag()],
        Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
        Annotated[CLIPEmbedDiffusersConfig, CLIPEmbedDiffusersConfig.get_tag()],
    ],
    Discriminator(get_model_discriminator_value),
]

# Reusable validator for the union above (used by ModelConfigFactory.make_config).
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)

# Either kind of per-model default settings.
AnyDefaultSettings: TypeAlias = Union[MainModelDefaultSettings, ControlAdapterDefaultSettings]
2024-02-06 03:56:32 +00:00
2024-03-06 19:18:21 +00:00
2023-11-05 03:03:26 +00:00
class ModelConfigFactory(object):
    """Class for parsing config dicts into model config objects."""

    @classmethod
    def make_config(
        cls,
        model_data: Union[Dict[str, Any], AnyModelConfig],
        key: Optional[str] = None,
        dest_class: Optional[Type[ModelConfigBase]] = None,
        timestamp: Optional[float] = None,
    ) -> AnyModelConfig:
        """
        Return the appropriate config object from raw dict values.

        :param model_data: A raw dict corresponding to the object fields to be
          parsed into a ModelConfigBase object (or descendant), or a ModelConfigBase
          object, which is used as-is without re-validation. Note that ``key`` and
          ``timestamp`` are still applied to such an object, i.e. it is mutated in place.
        :param key: If given (and non-empty), overrides the config's ``key``.
        :param dest_class: The config class to be returned. If not provided, it is
          selected automatically from the ``type``/``format`` fields via the
          ``AnyModelConfig`` discriminated union. Ignored when ``model_data`` is
          already a ModelConfigBase.
        :param timestamp: If given and the config is checkpoint-based, stored as
          its ``converted_at`` value.
        :returns: The resulting config object.
        :raises: pydantic ``ValidationError`` when ``model_data`` does not validate;
          ``validate_hash`` may also raise for an unacceptable ``hash`` (see that helper).
        """
        model: ModelConfigBase
        if isinstance(model_data, ModelConfigBase):
            model = model_data
        elif dest_class:
            model = dest_class.model_validate(model_data)
        else:
            # mypy doesn't typecheck TypeAdapters well?
            model = AnyModelConfigValidator.validate_python(model_data)  # type: ignore

        # validate_assignment=True on ModelConfigBase re-validates these writes.
        if key:
            model.key = key
        if isinstance(model, CheckpointConfigBase) and timestamp is not None:
            model.converted_at = timestamp

        # `model` is always a validated config here, so the hash check runs unconditionally.
        validate_hash(model.hash)
        return model  # type: ignore