loaders for main, controlnet, ip-adapter, clipvision and t2i

This commit is contained in:
Lincoln Stein
2024-02-04 17:23:10 -05:00
committed by psychedelicious
parent 60aa3d4893
commit 34d5cad4c9
32 changed files with 1123 additions and 159 deletions

View File

@ -20,14 +20,16 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
import torch
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
import torch
from diffusers import ModelMixin
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase):
vae: Optional[str] = Field(default=None)
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
ztsnr_training: bool = False
@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
"""Model config for main diffusers models."""
type: Literal[ModelType.Main] = ModelType.Main
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
class ONNXSD1Config(_MainConfig):
@ -276,6 +278,7 @@ AnyModelConfig = Union[
_ONNXConfig,
_VaeConfig,
_ControlNetConfig,
# ModelConfigBase,
LoRAConfig,
TextualInversionConfig,
IPAdapterConfig,
@ -284,7 +287,7 @@ AnyModelConfig = Union[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
@ -317,7 +320,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
timestamp: Optional[float] = None
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.