implement workaround for FastAPI and discriminated unions in Body parameter

This commit is contained in:
Lincoln Stein 2023-11-11 12:22:38 -05:00
parent 2b36565e9e
commit af2264b6eb
3 changed files with 46 additions and 26 deletions

View File

@ -3,6 +3,8 @@
from typing import List, Optional
from hashlib import sha1
from random import randbytes
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
@ -85,8 +87,7 @@ async def get_model_record(
)
async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")],
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
info: AnyModelConfig,
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger
@ -134,7 +135,7 @@ async def del_model_record(
status_code=201,
)
async def add_model_record(
config: AnyModelConfig,
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig:
"""
Add a model using the configuration information appropriate for its type.

View File

@ -257,8 +257,7 @@ _ControlNetConfig = Annotated[
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Annotated[
Union[
AnyModelConfig = Union[
_MainModelConfig,
_ONNXConfig,
_VaeConfig,
@ -268,10 +267,24 @@ AnyModelConfig = Annotated[
IPAdapterConfig,
CLIPVisionDiffusersConfig,
T2IConfig,
],
Body(discriminator="type"),
]
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route.
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
@ -295,9 +308,9 @@ class ModelConfigFactory(object):
be selected automatically.
"""
if isinstance(model_data, ModelConfigBase):
if key:
model_data.key = key
return model_data
model = model_data
elif dest_class:
model = dest_class.validate_python(model_data)
else:
model = AnyModelConfigValidator.validate_python(model_data)
if key:

View File

@ -44,6 +44,12 @@ def example_config() -> TextualInversionConfig:
)
def test_type(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig
def test_add(store: ModelRecordServiceBase):
raw = dict(
path="/tmp/foo.ckpt",