mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement workaround for FastAPI and discriminated unions in Body parameter
This commit is contained in:
parent
2b36565e9e
commit
af2264b6eb
@ -3,6 +3,8 @@
|
|||||||
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
from hashlib import sha1
|
||||||
|
from random import randbytes
|
||||||
|
|
||||||
from fastapi import Body, Path, Query, Response
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
@ -34,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList)
|
|||||||
operation_id="list_model_records",
|
operation_id="list_model_records",
|
||||||
)
|
)
|
||||||
async def list_model_records(
|
async def list_model_records(
|
||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
@ -61,7 +63,7 @@ async def list_model_records(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def get_model_record(
|
async def get_model_record(
|
||||||
key: str = Path(description="Key of the model record to fetch."),
|
key: str = Path(description="Key of the model record to fetch."),
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Get a model record"""
|
"""Get a model record"""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
@ -84,9 +86,8 @@ async def get_model_record(
|
|||||||
response_model=AnyModelConfig,
|
response_model=AnyModelConfig,
|
||||||
)
|
)
|
||||||
async def update_model_record(
|
async def update_model_record(
|
||||||
key: Annotated[str, Path(description="Unique key of model")],
|
key: Annotated[str, Path(description="Unique key of model")],
|
||||||
# info: Annotated[AnyModelConfig, Body(description="Model configuration")],
|
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||||
info: AnyModelConfig,
|
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
@ -134,7 +135,7 @@ async def del_model_record(
|
|||||||
status_code=201,
|
status_code=201,
|
||||||
)
|
)
|
||||||
async def add_model_record(
|
async def add_model_record(
|
||||||
config: AnyModelConfig,
|
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
|
||||||
) -> AnyModelConfig:
|
) -> AnyModelConfig:
|
||||||
"""
|
"""
|
||||||
Add a model using the configuration information appropriate for its type.
|
Add a model using the configuration information appropriate for its type.
|
||||||
|
@ -257,21 +257,34 @@ _ControlNetConfig = Annotated[
|
|||||||
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
|
||||||
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
|
||||||
|
|
||||||
AnyModelConfig = Annotated[
|
AnyModelConfig = Union[
|
||||||
Union[
|
_MainModelConfig,
|
||||||
_MainModelConfig,
|
_ONNXConfig,
|
||||||
_ONNXConfig,
|
_VaeConfig,
|
||||||
_VaeConfig,
|
_ControlNetConfig,
|
||||||
_ControlNetConfig,
|
LoRAConfig,
|
||||||
LoRAConfig,
|
TextualInversionConfig,
|
||||||
TextualInversionConfig,
|
IPAdapterConfig,
|
||||||
IPAdapterConfig,
|
CLIPVisionDiffusersConfig,
|
||||||
CLIPVisionDiffusersConfig,
|
T2IConfig,
|
||||||
T2IConfig,
|
|
||||||
],
|
|
||||||
Body(discriminator="type"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route.
|
||||||
|
# AnyModelConfig = Annotated[
|
||||||
|
# Union[
|
||||||
|
# _MainModelConfig,
|
||||||
|
# _ONNXConfig,
|
||||||
|
# _VaeConfig,
|
||||||
|
# _ControlNetConfig,
|
||||||
|
# LoRAConfig,
|
||||||
|
# TextualInversionConfig,
|
||||||
|
# IPAdapterConfig,
|
||||||
|
# CLIPVisionDiffusersConfig,
|
||||||
|
# T2IConfig,
|
||||||
|
# ],
|
||||||
|
# Field(discriminator="type"),
|
||||||
|
# ]
|
||||||
|
|
||||||
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
AnyModelConfigValidator = TypeAdapter(AnyModelConfig)
|
||||||
|
|
||||||
|
|
||||||
@ -295,11 +308,11 @@ class ModelConfigFactory(object):
|
|||||||
be selected automatically.
|
be selected automatically.
|
||||||
"""
|
"""
|
||||||
if isinstance(model_data, ModelConfigBase):
|
if isinstance(model_data, ModelConfigBase):
|
||||||
if key:
|
model = model_data
|
||||||
model_data.key = key
|
elif dest_class:
|
||||||
return model_data
|
model = dest_class.validate_python(model_data)
|
||||||
else:
|
else:
|
||||||
model = AnyModelConfigValidator.validate_python(model_data)
|
model = AnyModelConfigValidator.validate_python(model_data)
|
||||||
if key:
|
if key:
|
||||||
model.key = key
|
model.key = key
|
||||||
return model
|
return model
|
||||||
|
@ -44,6 +44,12 @@ def example_config() -> TextualInversionConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_type(store: ModelRecordServiceBase):
|
||||||
|
config = example_config()
|
||||||
|
store.add_model("key1", config)
|
||||||
|
config1 = store.get_model("key1")
|
||||||
|
assert type(config1) == TextualInversionConfig
|
||||||
|
|
||||||
def test_add(store: ModelRecordServiceBase):
|
def test_add(store: ModelRecordServiceBase):
|
||||||
raw = dict(
|
raw = dict(
|
||||||
path="/tmp/foo.ckpt",
|
path="/tmp/foo.ckpt",
|
||||||
|
Loading…
Reference in New Issue
Block a user