mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model loading and conversion implemented for vaes
This commit is contained in:
committed by
psychedelicious
parent
5c2884569e
commit
60aa3d4893
@ -19,12 +19,15 @@ Typical usage:
|
||||
Validation errors will raise an InvalidModelConfigException error.
|
||||
|
||||
"""
|
||||
import time
|
||||
import torch
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
|
||||
from diffusers import ModelMixin
|
||||
from typing_extensions import Annotated, Any, Dict
|
||||
|
||||
from .onnx_runtime import IAIOnnxRuntimeModel
|
||||
|
||||
class InvalidModelConfigException(Exception):
|
||||
"""Exception for when config parser doesn't recognized this combination of model type and format."""
|
||||
@ -127,6 +130,7 @@ class ModelConfigBase(BaseModel):
|
||||
) # if model is converted or otherwise modified, this will hold updated hash
|
||||
description: Optional[str] = Field(default=None)
|
||||
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
|
||||
last_modified: Optional[float] = Field(description="Timestamp for modification time", default_factory=time.time)
|
||||
|
||||
model_config = ConfigDict(
|
||||
use_enum_values=False,
|
||||
@ -280,6 +284,7 @@ AnyModelConfig = Union[
|
||||
]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
AnyModel = Union[ModelMixin, torch.nn.Module, IAIOnnxRuntimeModel]
|
||||
|
||||
# IMPLEMENTATION NOTE:
|
||||
# The preferred alternative to the above is a discriminated Union as shown
|
||||
@ -312,6 +317,7 @@ class ModelConfigFactory(object):
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
timestamp: Optional[float] = None
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Return the appropriate config object from raw dict values.
|
||||
@ -330,4 +336,6 @@ class ModelConfigFactory(object):
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
if key:
|
||||
model.key = key
|
||||
if timestamp:
|
||||
model.last_modified = timestamp
|
||||
return model
|
||||
|
Reference in New Issue
Block a user