diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 0fcfe5bb18..1040e598f0 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -46,10 +46,10 @@ async def list_model_records( models_raw = list() for base_model in base_models: models_raw.extend( - [x.model_dump() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)] + [x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)] ) else: - models_raw = [x.model_dump() for x in record_store.search_by_name(model_type=model_type)] + models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)] models = ModelsListValidator.validate_python({"models": models_raw}) return models @@ -108,7 +108,6 @@ async def update_model_record( operation_id="del_model_record", responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}}, status_code=204, - response_model=None, ) async def del_model_record( key: str = Path(description="Unique key of model to remove from model registry."), @@ -131,12 +130,10 @@ async def del_model_record( operation_id="add_model_record", responses={ 201: {"description": "The model added successfully"}, - 404: {"description": "The model could not be found"}, 409: {"description": "There is already a model corresponding to this path or repo_id"}, 415: {"description": "Unrecognized file/folder format"}, }, status_code=201, - response_model=AnyModelConfig, ) async def add_model_record( config: AnyModelConfig = Body(description="Model configuration"), diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 1af3ec8859..0fb027e3fa 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -314,17 +314,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): """ count = 0 with self._db.lock: - try: - self._cursor.execute( - """--sql - select count(*) FROM model_config - WHERE id=?; - """, - (key,), - ) - count = self._cursor.fetchone()[0] - except sqlite3.Error as e: - raise e + self._cursor.execute( + """--sql + select count(*) FROM model_config + WHERE id=?; + """, + (key,), + ) + count = self._cursor.fetchone()[0] return count > 0 def search_by_attr( @@ -357,49 +354,40 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): bindings.append(model_type) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" with self._db.lock: - try: - self._cursor.execute( - f"""--sql - select config FROM model_config - {where}; - """, - tuple(bindings), - ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] - except sqlite3.Error as e: - raise e + self._cursor.execute( + f"""--sql + select config FROM model_config + {where}; + """, + tuple(bindings), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] return results def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]: """Return models with the indicated path.""" results = [] with self._db.lock: - try: - self._cursor.execute( - """--sql - SELECT config FROM model_config - WHERE model_path=?; - """, - (str(path),), - ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] - except sqlite3.Error as e: - raise e + self._cursor.execute( + """--sql + SELECT config FROM model_config + WHERE model_path=?; + """, + (str(path),), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] return results def search_by_hash(self, hash: str) -> List[ModelConfigBase]: """Return models with the indicated original_hash.""" results = [] with self._db.lock: - try: - self._cursor.execute( - """--sql - SELECT config FROM model_config - WHERE original_hash=?; - """, - (hash,), - ) - results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] - except sqlite3.Error as e: - raise e + self._cursor.execute( + """--sql + SELECT config FROM model_config + WHERE original_hash=?; + """, + (hash,), + ) + results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()] return results