mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
loaders for main, controlnet, ip-adapter, clipvision and t2i
This commit is contained in:
committed by
psychedelicious
parent
60aa3d4893
commit
34d5cad4c9
@ -20,14 +20,16 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
import time
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from .onnx_runtime import IAIOnnxRuntimeModel
|
||||
from ..ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
@ -204,6 +206,8 @@ class _MainConfig(ModelConfigBase):
|
||||
|
||||
vae: Optional[str] = Field(default=None)
|
||||
variant: ModelVariantType = ModelVariantType.Normal
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
ztsnr_training: bool = False
|
||||
|
||||
|
||||
@ -217,8 +221,6 @@ class MainDiffusersConfig(_DiffusersConfig, _MainConfig):
|
||||
"""Model config for main diffusers models."""
|
||||
|
||||
type: Literal[ModelType.Main] = ModelType.Main
|
||||
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
|
||||
upcast_attention: bool = False
|
||||
|
||||
|
||||
class ONNXSD1Config(_MainConfig):
|
||||
@ -276,6 +278,7 @@ AnyModelConfig = Union[
|
||||
_ONNXConfig,
|
||||
_VaeConfig,
|
||||
_ControlNetConfig,
|
||||
# ModelConfigBase,
|
||||
LoRAConfig,
|
||||
TextualInversionConfig,
|
||||
IPAdapterConfig,
|
||||
@ -284,7 +287,7 @@ AnyModelConfig = Union[
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel, IPAdapter, IPAdapterPlus]
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
@ -317,7 +320,7 @@ class ModelConfigFactory(object):
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
timestamp: Optional[float] = None
|
||||
timestamp: Optional[float] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
|
Reference in New Issue
Block a user