remove cruft code from router

This commit is contained in:
Lincoln Stein 2023-11-10 18:49:25 -05:00
parent b55fc2935e
commit bd56e9bc81
2 changed files with 34 additions and 49 deletions

View File

@ -46,10 +46,10 @@ async def list_model_records(
models_raw = list()
for base_model in base_models:
models_raw.extend(
[x.model_dump() for x in record_store.search_by_name(base_model=base_model, model_type=model_type)]
[x.model_dump() for x in record_store.search_by_attr(base_model=base_model, model_type=model_type)]
)
else:
models_raw = [x.model_dump() for x in record_store.search_by_name(model_type=model_type)]
models_raw = [x.model_dump() for x in record_store.search_by_attr(model_type=model_type)]
models = ModelsListValidator.validate_python({"models": models_raw})
return models
@ -108,7 +108,6 @@ async def update_model_record(
operation_id="del_model_record",
responses={204: {"description": "Model deleted successfully"}, 404: {"description": "Model not found"}},
status_code=204,
response_model=None,
)
async def del_model_record(
key: str = Path(description="Unique key of model to remove from model registry."),
@ -131,12 +130,10 @@ async def del_model_record(
operation_id="add_model_record",
responses={
201: {"description": "The model added successfully"},
404: {"description": "The model could not be found"},
409: {"description": "There is already a model corresponding to this path or repo_id"},
415: {"description": "Unrecognized file/folder format"},
},
status_code=201,
response_model=AnyModelConfig,
)
async def add_model_record(
config: AnyModelConfig = Body(description="Model configuration"),

View File

@ -314,17 +314,14 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
"""
count = 0
with self._db.lock:
try:
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
except sqlite3.Error as e:
raise e
self._cursor.execute(
"""--sql
select count(*) FROM model_config
WHERE id=?;
""",
(key,),
)
count = self._cursor.fetchone()[0]
return count > 0
def search_by_attr(
@ -357,49 +354,40 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
bindings.append(model_type)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
try:
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
self._cursor.execute(
f"""--sql
select config FROM model_config
{where};
""",
tuple(bindings),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
"""Return models with the indicated path."""
results = []
with self._db.lock:
try:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE model_path=?;
""",
(str(path),),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
"""Return models with the indicated original_hash."""
results = []
with self._db.lock:
try:
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
except sqlite3.Error as e:
raise e
self._cursor.execute(
"""--sql
SELECT config FROM model_config
WHERE original_hash=?;
""",
(hash,),
)
results = [ModelConfigFactory.make_config(json.loads(x[0])) for x in self._cursor.fetchall()]
return results