implement workaround for FastAPI and discriminated unions in Body parameter

This commit is contained in:
Lincoln Stein 2023-11-11 12:22:38 -05:00
parent 2b36565e9e
commit af2264b6eb
3 changed files with 46 additions and 26 deletions

View File

@ -3,6 +3,8 @@
from typing import List, Optional from typing import List, Optional
from hashlib import sha1
from random import randbytes
from fastapi import Body, Path, Query, Response from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
@ -34,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList)
operation_id="list_model_records", operation_id="list_model_records",
) )
async def list_model_records( async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
@ -61,7 +63,7 @@ async def list_model_records(
}, },
) )
async def get_model_record( async def get_model_record(
key: str = Path(description="Key of the model record to fetch."), key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Get a model record""" """Get a model record"""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
@ -84,9 +86,8 @@ async def get_model_record(
response_model=AnyModelConfig, response_model=AnyModelConfig,
) )
async def update_model_record( async def update_model_record(
key: Annotated[str, Path(description="Unique key of model")], key: Annotated[str, Path(description="Unique key of model")],
# info: Annotated[AnyModelConfig, Body(description="Model configuration")], info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
info: AnyModelConfig,
) -> AnyModelConfig: ) -> AnyModelConfig:
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -134,7 +135,7 @@ async def del_model_record(
status_code=201, status_code=201,
) )
async def add_model_record( async def add_model_record(
config: AnyModelConfig, config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
) -> AnyModelConfig: ) -> AnyModelConfig:
""" """
Add a model using the configuration information appropriate for its type. Add a model using the configuration information appropriate for its type.

View File

@ -257,21 +257,34 @@ _ControlNetConfig = Annotated[
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")] _VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")] _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
AnyModelConfig = Annotated[ AnyModelConfig = Union[
Union[ _MainModelConfig,
_MainModelConfig, _ONNXConfig,
_ONNXConfig, _VaeConfig,
_VaeConfig, _ControlNetConfig,
_ControlNetConfig, LoRAConfig,
LoRAConfig, TextualInversionConfig,
TextualInversionConfig, IPAdapterConfig,
IPAdapterConfig, CLIPVisionDiffusersConfig,
CLIPVisionDiffusersConfig, T2IConfig,
T2IConfig,
],
Body(discriminator="type"),
] ]
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route.
# AnyModelConfig = Annotated[
# Union[
# _MainModelConfig,
# _ONNXConfig,
# _VaeConfig,
# _ControlNetConfig,
# LoRAConfig,
# TextualInversionConfig,
# IPAdapterConfig,
# CLIPVisionDiffusersConfig,
# T2IConfig,
# ],
# Field(discriminator="type"),
# ]
AnyModelConfigValidator = TypeAdapter(AnyModelConfig) AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
@ -295,11 +308,11 @@ class ModelConfigFactory(object):
be selected automatically. be selected automatically.
""" """
if isinstance(model_data, ModelConfigBase): if isinstance(model_data, ModelConfigBase):
if key: model = model_data
model_data.key = key elif dest_class:
return model_data model = dest_class.validate_python(model_data)
else: else:
model = AnyModelConfigValidator.validate_python(model_data) model = AnyModelConfigValidator.validate_python(model_data)
if key: if key:
model.key = key model.key = key
return model return model

View File

@ -44,6 +44,12 @@ def example_config() -> TextualInversionConfig:
) )
def test_type(store: ModelRecordServiceBase):
config = example_config()
store.add_model("key1", config)
config1 = store.get_model("key1")
assert type(config1) == TextualInversionConfig
def test_add(store: ModelRecordServiceBase): def test_add(store: ModelRecordServiceBase):
raw = dict( raw = dict(
path="/tmp/foo.ckpt", path="/tmp/foo.ckpt",