From ef8dcf5fae3a0e9e16caae97687c34b42e0e2bda Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 12 Nov 2023 14:20:32 -0500 Subject: [PATCH] blackify --- invokeai/app/api/routers/model_records.py | 12 ++++++------ .../services/model_records/test_model_records_sql.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 45665dfa77..298b30735a 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -36,8 +36,8 @@ ModelsListValidator = TypeAdapter(ModelsList) operation_id="list_model_records", ) async def list_model_records( - base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), - model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), + base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), + model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records @@ -63,7 +63,7 @@ async def list_model_records( }, ) async def get_model_record( - key: str = Path(description="Key of the model record to fetch."), + key: str = Path(description="Key of the model record to fetch."), ) -> AnyModelConfig: """Get a model record""" record_store = ApiDependencies.invoker.services.model_records @@ -86,8 +86,8 @@ async def get_model_record( response_model=AnyModelConfig, ) async def update_model_record( - key: Annotated[str, Path(description="Unique key of model")], - info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] + key: Annotated[str, Path(description="Unique key of model")], + info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")], ) -> AnyModelConfig: """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed.""" logger = ApiDependencies.invoker.services.logger @@ -135,7 +135,7 @@ async def del_model_record( status_code=201, ) async def add_model_record( - config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] + config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")] ) -> AnyModelConfig: """ Add a model using the configuration information appropriate for its type. diff --git a/tests/app/services/model_records/test_model_records_sql.py b/tests/app/services/model_records/test_model_records_sql.py index 48a5433a4d..52a7c40dfd 100644 --- a/tests/app/services/model_records/test_model_records_sql.py +++ b/tests/app/services/model_records/test_model_records_sql.py @@ -49,7 +49,8 @@ def test_type(store: ModelRecordServiceBase): store.add_model("key1", config) config1 = store.get_model("key1") assert type(config1) == TextualInversionConfig - + + def test_add(store: ModelRecordServiceBase): raw = dict( path="/tmp/foo.ckpt",