diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 997f76a185..934a7d15b3 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -45,6 +45,9 @@ async def list_model_records( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), + model_format: Optional[str] = Query( + default=None, description="Exact match on the format of the model (e.g. 'diffusers')" + ), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records @@ -52,10 +55,14 @@ async def list_model_records( if base_models: for base_model in base_models: found_models.extend( - record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name) + record_store.search_by_attr( + base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + ) ) else: - found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name)) + found_models.extend( + record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + ) return ModelsList(models=found_models) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 679d05fccd..ae0e633990 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Union -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType class DuplicateModelException(Exception): @@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_format: Optional[ModelFormat] = None, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC): :param model_name: Filter by name of model (optional) :param base_model: Filter by base model (optional) :param model_type: Filter by type of model (optional) + :param model_format: Filter by model format (e.g. "diffusers") (optional) If none of the optional filters are passed, will return all models in the database. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 08956a960f..a8e777cf64 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, ModelConfigFactory, + ModelFormat, ModelType, ) @@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_format: Optional[ModelFormat] = None, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param model_name: Filter by name of model (optional) :param base_model: Filter by base model (optional) :param model_type: Filter by type of model (optional) + :param model_format: Filter by model format (e.g. "diffusers") (optional) If none of the optional filters are passed, will return all models in the database. @@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): if model_type: where_clause.append("type=?") bindings.append(model_type) + if model_format: + where_clause.append("format=?") + bindings.append(model_format) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" with self._db.lock: self._cursor.execute(