model loading and conversion implemented for vaes

This commit is contained in:
Lincoln Stein
2024-02-03 22:55:09 -05:00
committed by psychedelicious
parent 5c2884569e
commit 60aa3d4893
29 changed files with 2382 additions and 237 deletions

View File

@ -19,12 +19,15 @@ Typical usage:
Validation errors will raise an InvalidModelConfigException error.
"""
import time
import torch
from enum import Enum
from typing import Literal, Optional, Type, Union
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
from diffusers import ModelMixin
from typing_extensions import Annotated, Any, Dict
from .onnx_runtime import IAIOnnxRuntimeModel
class InvalidModelConfigException(Exception):
"""Exception for when config parser doesn't recognized this combination of model type and format."""
@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel):
) # if model is converted or otherwise modified, this will hold updated hash
description: Optional[str] = Field(default=None)
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time)
model_config = ConfigDict(
use_enum_values=False,
@ -280,6 +284,7 @@ AnyModelConfig = Union[
]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
# IMPLEMENTATION NOTE:
# The preferred alternative to the above is a discriminated Union as shown
@ -312,6 +317,7 @@ class ModelConfigFactory(object):
model_data: Union[dict, AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
timestamp: Optional[float] = None
) -> AnyModelConfig:
"""
Return the appropriate config object from raw dict values.
@ -330,4 +336,6 @@ class ModelConfigFactory(object):
model = AnyModelConfigValidator.validate_python(model_data)
if key:
model.key = key
if timestamp:
model.last_modified = timestamp
return model