remove cruft code from router

This commit is contained in:
Lincoln Stein 2023-11-10 18:49:25 -05:00
parent b55fc2935e
commit bd56e9bc81
2 changed files with 34 additions and 49 deletions

View File

@ -46,10 +46,10 @@ async def list_model_records(
models_raw = list() models_raw = list()
for base_model in base_models: for base_model in base_models:
models_raw.extend( models_raw.extend(
[x.model_dump() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)] [x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
) )
else: else:
models_raw = [x.model_dump() for x in record_store.search_by_name(model_type=model_type)] models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)]
models = ModelsListValidator.validate_python({"models": models_raw}) models = ModelsListValidator.validate_python({"models": models_raw})
return models return models
@ -108,7 +108,6 @@ async def update_model_record(
operation_id="del_model_record", operation_id="del_model_record",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}}, responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204, status_code=204,
response_model=None,
) )
async def del_model_record( async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."), key: str = Path(description="Unique key of model to remove from model registry."),
@ -131,12 +130,10 @@ async def del_model_record(
operation_id="add_model_record", operation_id="add_model_record",
responses={ responses={
201: {"description": "The model added successfully"}, 201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to this path or repo_id"}, 409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"}, 415: {"description": "Unrecognized file/folder format"},
}, },
status_code=201, status_code=201,
response_model=AnyModelConfig,
) )
async def add_model_record( async def add_model_record(
config: AnyModelConfig = Body(description="Model configuration"), config: AnyModelConfig = Body(description="Model configuration"),

View File

@ -314,17 +314,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
""" """
count = 0 count = 0
with self._db.lock: with self._db.lock:
try: self._cursor.execute(
self._cursor.execute( """--sql
"""--sql select count(*) FROM model_config
select count(*) FROM model_config WHERE id=?;
WHERE id=?; """,
""", (key,),
(key,), )
) count = self._cursor.fetchone()[0]
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
raise e
return count > 0 return count > 0
def search_by_attr( def search_by_attr(
@ -357,49 +354,40 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
bindings.append(model_type) bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock: with self._db.lock:
try: self._cursor.execute(
self._cursor.execute( f"""--sql
f"""--sql select config FROM model_config
select config FROM model_config {where};
{where}; """,
""", tuple(bindings),
tuple(bindings), )
) results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]: def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path.""" """Return models with the indicated path."""
results = [] results = []
with self._db.lock: with self._db.lock:
try: self._cursor.execute(
self._cursor.execute( """--sql
"""--sql SELECT config FROM model_config
SELECT config FROM model_config WHERE model_path=?;
WHERE model_path=?; """,
""", (str(path),),
(str(path),), )
) results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]: def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash.""" """Return models with the indicated original_hash."""
results = [] results = []
with self._db.lock: with self._db.lock:
try: self._cursor.execute(
self._cursor.execute( """--sql
"""--sql SELECT config FROM model_config
SELECT config FROM model_config WHERE original_hash=?;
WHERE original_hash=?; """,
""", (hash,),
(hash,), )
) results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
return results return results