diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 7b66ad876e..34e1dc378a 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -3,6 +3,8 @@ from typing import List, Optional +from hashlib import sha1 +from random import randbytes from fastapi import Body, Path, Query, Response from fastapi.routing import APIRouter @@ -34,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList) operation_id="list_model_records", ) async def list_model_records( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), + base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records @@ -61,7 +63,7 @@ async def list_model_records( }, ) async def get_model_record( - key: str = Path(description="Key of the model record to fetch."), + key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" record_store = ApiDependencies.invoker.services.model_records @@ -84,9 +86,8 @@ async def get_model_record( response_model=AnyModelConfig, ) async def update_model_record( - key: Annotated[str, Path(description="Unique key of model")], - # info: Annotated[AnyModelConfig, Body(description="Model configuration")], - info: AnyModelConfig, + key: Annotated[str, Path(description="Unique key of model")], + info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] ) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger @@ -134,7 +135,7 @@ async def del_model_record( status_code=201, ) async def add_model_record( - config: AnyModelConfig, + config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] ) -> AnyModelConfig: """ Add a model using the configuration information appropriate for its type. diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index cf28bbeb17..bc369fa297 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -257,21 +257,34 @@ _ControlNetConfig = Annotated[ _VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")] _MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")] -AnyModelConfig = Annotated[ - Union[ - _MainModelConfig, - _ONNXConfig, - _VaeConfig, - _ControlNetConfig, - LoRAConfig, - TextualInversionConfig, - IPAdapterConfig, - CLIPVisionDiffusersConfig, - T2IConfig, - ], - Body(discriminator="type"), +AnyModelConfig = Union[ + _MainModelConfig, + _ONNXConfig, + _VaeConfig, + _ControlNetConfig, + LoRAConfig, + TextualInversionConfig, + IPAdapterConfig, + CLIPVisionDiffusersConfig, + T2IConfig, ] +# Preferred alternative is a discriminated Union, but it breaks FastAPI when applied to a route. +# AnyModelConfig = Annotated[ +# Union[ +# _MainModelConfig, +# _ONNXConfig, +# _VaeConfig, +# _ControlNetConfig, +# LoRAConfig, +# TextualInversionConfig, +# IPAdapterConfig, +# CLIPVisionDiffusersConfig, +# T2IConfig, +# ], +# Field(discriminator="type"), +# ] + AnyModelConfigValidator = TypeAdapter(AnyModelConfig) @@ -295,11 +308,11 @@ class ModelConfigFactory(object): be selected automatically. """ if isinstance(model_data, ModelConfigBase): - if key: - model_data.key = key - return model_data + model = model_data + elif dest_class: + model = dest_class.validate_python(model_data) else: model = AnyModelConfigValidator.validate_python(model_data) - if key: - model.key = key - return model + if key: + model.key = key + return model diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index c856cb7c02..48a5433a4d 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -44,6 +44,12 @@ def example_config() -> TextualInversionConfig: ) +def test_type(store: ModelRecordServiceBase): + config = example_config() + store.add_model("key1", config) + config1 = store.get_model("key1") + assert type(config1) == TextualInversionConfig + def test_add(store: ModelRecordServiceBase): raw = dict( path="/tmp/foo.ckpt",