add ability to filter model listings by format

This commit is contained in:
Lincoln Stein
2023-12-13 15:59:21 -05:00
parent 340957f920
commit 569ae7c482
3 changed files with 12 additions and 3 deletions

View File

@ -45,6 +45,7 @@ async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
@ -52,10 +53,10 @@ async def list_model_records(
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format)
)
else:
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format))
return ModelsList(models=found_models)