from enum import Enum
from typing import Dict, TypeAlias, Union

import diffusers
import onnxruntime as ort
import torch
from diffusers import ModelMixin

from invokeai.backend.raw_model import RawModel

# ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[
    ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, ort.InferenceSession
]
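
# Illustrative note (an addition, not part of the upstream module): AnyModel is
# an annotation-only alias. Members such as Dict[str, torch.Tensor] are
# subscripted generics, which cannot be used with isinstance(), so a
# hypothetical consumer narrows on the concrete member types instead, e.g.:
#
#     def move_to_cpu(model: AnyModel) -> None:
#         if isinstance(model, torch.nn.Module):
#             model.to("cpu")
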

class BaseModelType(str, Enum):
    """Base model type."""

    Any = "any"
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusion3 = "sd-3"
    StableDiffusionXL = "sdxl"
    StableDiffusionXLRefiner = "sdxl-refiner"
    Flux = "flux"
    CogView4 = "cogview4"
    Imagen3 = "imagen3"
    Imagen4 = "imagen4"
    ChatGPT4o = "chatgpt-4o"


class ModelType(str, Enum):
    """Model type."""

    ONNX = "onnx"
    Main = "main"
    VAE = "vae"
    LoRA = "lora"
    ControlLoRa = "control_lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"
    CLIPEmbed = "clip_embed"
    T2IAdapter = "t2i_adapter"
    T5Encoder = "t5_encoder"
    SpandrelImageToImage = "spandrel_image_to_image"
    SigLIP = "siglip"
    FluxRedux = "flux_redux"
    LlavaOnevision = "llava_onevision"


class SubModelType(str, Enum):
    """Submodel type."""

    UNet = "unet"
    Transformer = "transformer"
    TextEncoder = "text_encoder"
    TextEncoder2 = "text_encoder_2"
    TextEncoder3 = "text_encoder_3"
    Tokenizer = "tokenizer"
    Tokenizer2 = "tokenizer_2"
    Tokenizer3 = "tokenizer_3"
    VAE = "vae"
    VAEDecoder = "vae_decoder"
    VAEEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"


class ClipVariantType(str, Enum):
    """CLIP variant type."""

    L = "large"
    G = "gigantic"


class ModelVariantType(str, Enum):
    """Model variant type."""

    Normal = "normal"
    Inpaint = "inpaint"
    Depth = "depth"


class ModelFormat(str, Enum):
    """Storage format of model."""

    OMI = "omi"
    Diffusers = "diffusers"
    Checkpoint = "checkpoint"
    LyCORIS = "lycoris"
    ONNX = "onnx"
    Olive = "olive"
    EmbeddingFile = "embedding_file"
    EmbeddingFolder = "embedding_folder"
    InvokeAI = "invokeai"
    T5Encoder = "t5_encoder"
    BnbQuantizedLlmInt8b = "bnb_quantized_int8b"
    BnbQuantizednf4b = "bnb_quantized_nf4b"
    GGUFQuantized = "gguf_quantized"
    Api = "api"


class SchedulerPredictionType(str, Enum):
    """Scheduler prediction type."""

    Epsilon = "epsilon"
    VPrediction = "v_prediction"
    Sample = "sample"


class ModelRepoVariant(str, Enum):
    """Various Hugging Face variants on the diffusers format."""

    Default = ""  # model files without "fp16" or other qualifier
    FP16 = "fp16"
    FP32 = "fp32"
    ONNX = "onnx"
    OpenVINO = "openvino"
    Flax = "flax"


class ModelSourceType(str, Enum):
    """Model source type."""

    Path = "path"
    Url = "url"
    HFRepoID = "hf_repo_id"


class FluxLoRAFormat(str, Enum):
    """Flux LoRA formats."""

    Diffusers = "flux.diffusers"
    Kohya = "flux.kohya"
    OneTrainer = "flux.onetrainer"
    Control = "flux.control"


AnyVariant: TypeAlias = Union[ModelVariantType, ClipVariantType, None]
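

# Minimal smoke test (an illustrative addition, not part of the upstream
# module): because every taxonomy enum here subclasses str, members compare
# equal to their raw string values and round-trip through the enum
# constructor, which is what lets serialized model configs store plain strings.
if __name__ == "__main__":
    assert BaseModelType("sd-1") is BaseModelType.StableDiffusion1
    assert ModelType.LoRA == "lora"  # str mixin: member == raw value
    assert SubModelType("vae") in {SubModelType.VAE, SubModelType.VAEDecoder}
    assert ModelFormat.Diffusers.value == "diffusers"
    print("taxonomy enums round-trip OK")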