add ability to filter model listings by format

This commit is contained in:
Lincoln Stein 2023-12-13 15:59:21 -05:00
parent 340957f920
commit 569ae7c482
3 changed files with 12 additions and 3 deletions

View File

@ -45,6 +45,7 @@ async def list_model_records(
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"),
) -> ModelsList: ) -> ModelsList:
"""Get a list of models.""" """Get a list of models."""
record_store = ApiDependencies.invoker.services.model_records record_store = ApiDependencies.invoker.services.model_records
@ -52,10 +53,10 @@ async def list_model_records(
if base_models: if base_models:
for base_model in base_models: for base_model in base_models:
found_models.extend( found_models.extend(
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name) record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format)
) )
else: else:
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name)) found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format))
return ModelsList(models=found_models) return ModelsList(models=found_models)

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List, Optional, Union
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
class DuplicateModelException(Exception): class DuplicateModelException(Exception):
@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC):
model_name: Optional[str] = None, model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]: ) -> List[AnyModelConfig]:
""" """
Return models matching name, base and/or type. Return models matching name, base and/or type.
@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC):
:param model_name: Filter by name of model (optional) :param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional) :param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional) :param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all If none of the optional filters are passed, will return all
models in the database. models in the database.

View File

@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import (
AnyModelConfig, AnyModelConfig,
BaseModelType, BaseModelType,
ModelConfigFactory, ModelConfigFactory,
ModelFormat,
ModelType, ModelType,
) )
@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
model_name: Optional[str] = None, model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None, base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None, model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
) -> List[AnyModelConfig]: ) -> List[AnyModelConfig]:
""" """
Return models matching name, base and/or type. Return models matching name, base and/or type.
@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param model_name: Filter by name of model (optional) :param model_name: Filter by name of model (optional)
:param base_model: Filter by base model (optional) :param base_model: Filter by base model (optional)
:param model_type: Filter by type of model (optional) :param model_type: Filter by type of model (optional)
:param model_format: Filter by model format (e.g. "diffusers") (optional)
If none of the optional filters are passed, will return all If none of the optional filters are passed, will return all
models in the database. models in the database.
@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
if model_type: if model_type:
where_clause.append("type=?") where_clause.append("type=?")
bindings.append(model_type) bindings.append(model_type)
if model_format:
where_clause.append("format=?")
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
with self._db.lock: with self._db.lock:
self._cursor.execute( self._cursor.execute(