add ability to filter model listings by format

This commit is contained in:
Lincoln Stein 2023-12-13 15:59:21 -05:00
parent 340957f920
commit 569ae7c482
3 changed files with 12 additions and 3 deletions

View File

@ -45,6 +45,7 @@ async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"),
) -> ModelsList:
"""Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records
@ -52,10 +53,10 @@ async def list_model_records(
if base_models:
for base_model in base_models:
found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format)
)
else:
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format))
return ModelsList(models=found_models)

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
class DuplicateModelException(Exception):
@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.

View File

@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
ModelConfigFactory,
ModelFormat,
ModelType,
)
@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]:
"""
Return models matching name, base and/or type.
@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all
models in the database.
@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if model_type:
where_clause.append("type=?")
bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock:
self._cursor.execute(