mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement workaround for FastAPI and discriminated unions in Body parameter
This commit is contained in:
parent
2b36565e9e
commit
af2264b6eb
@ -3,6 +3,8 @@
|
||||
|
||||
|
||||
from typing import List, Optional
|
||||
from hashlib import sha1
|
||||
from random import randbytes
|
||||
|
||||
from fastapi import Body, Path, Query, Response
|
||||
from fastapi.routing import APIRouter
|
||||
@ -85,8 +87,7 @@ async def get_model_record(
|
||||
)
|
||||
async def update_model_record(
|
||||
key: Annotated[str, Path(description="Unique key of model")],
|
||||
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
|
||||
info: AnyModelConfig,
|
||||
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
) -> AnyModelConfig:
|
||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
@ -134,7 +135,7 @@ async def del_model_record(
|
||||
status_code=201,
|
||||
)
|
||||
async def add_model_record(
|
||||
config: AnyModelConfig,
|
||||
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
Add a model using the configuration information appropriate for its type.
|
||||
|
@ -257,8 +257,7 @@ _ControlNetConfig = Annotated[
|
||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||
|
||||
AnyModelConfig = Annotated[
|
||||
Union[
|
||||
AnyModelConfig = Union[
|
||||
_MainModelConfig,
|
||||
_ONNXConfig,
|
||||
_VaeConfig,
|
||||
@ -268,10 +267,24 @@ AnyModelConfig = Annotated[
|
||||
IPAdapterConfig,
|
||||
CLIPVisionDiffusersConfig,
|
||||
T2IConfig,
|
||||
],
|
||||
Body(discriminator="type"),
|
||||
]
|
||||
|
||||
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route.
|
||||
# AnyModelConfig = Annotated[
|
||||
# Union[
|
||||
# _MainModelConfig,
|
||||
# _ONNXConfig,
|
||||
# _VaeConfig,
|
||||
# _ControlNetConfig,
|
||||
# LoRAConfig,
|
||||
# TextualInversionConfig,
|
||||
# IPAdapterConfig,
|
||||
# CLIPVisionDiffusersConfig,
|
||||
# T2IConfig,
|
||||
# ],
|
||||
# Field(discriminator="type"),
|
||||
# ]
|
||||
|
||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||
|
||||
|
||||
@ -295,9 +308,9 @@ class ModelConfigFactory(object):
|
||||
be selected automatically.
|
||||
"""
|
||||
if isinstance(model_data, ModelConfigBase):
|
||||
if key:
|
||||
model_data.key = key
|
||||
return model_data
|
||||
model = model_data
|
||||
elif dest_class:
|
||||
model = dest_class.validate_python(model_data)
|
||||
else:
|
||||
model = AnyModelConfigValidator.validate_python(model_data)
|
||||
if key:
|
||||
|
@ -44,6 +44,12 @@ def example_config() -> TextualInversionConfig:
|
||||
)
|
||||
|
||||
|
||||
def test_type(store: ModelRecordServiceBase):
|
||||
config = example_config()
|
||||
store.add_model("key1", config)
|
||||
config1 = store.get_model("key1")
|
||||
assert type(config1) == TextualInversionConfig
|
||||
|
||||
def test_add(store: ModelRecordServiceBase):
|
||||
raw = dict(
|
||||
path="/tmp/foo.ckpt",
|
||||
|
Loading…
x
Reference in New Issue
Block a user