From 569ae7c482a6d8889d49ba8be93f759454edf574 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 13 Dec 2023 15:59:21 -0500 Subject: [PATCH 1/2] add ability to filter model listings by format --- invokeai/app/api/routers/model_records.py | 5 +++-- invokeai/app/services/model_records/model_records_base.py | 4 +++- invokeai/app/services/model_records/model_records_sql.py | 6 ++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 997f76a185..8d029db422 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -45,6 +45,7 @@ async def list_model_records( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), + model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records @@ -52,10 +53,10 @@ async def list_model_records( if base_models: for base_model in base_models: found_models.extend( - record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name) + record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format) ) else: - found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name)) + found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)) return ModelsList(models=found_models) diff --git a/invokeai/app/services/model_records/model_records_base.py b/invokeai/app/services/model_records/model_records_base.py index 679d05fccd..ae0e633990 100644 --- a/invokeai/app/services/model_records/model_records_base.py +++ b/invokeai/app/services/model_records/model_records_base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from pathlib import Path from typing import List, Optional, Union -from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType +from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType class DuplicateModelException(Exception): @@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_format: Optional[ModelFormat] = None, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC): :param model_name: Filter by name of model (optional) :param base_model: Filter by base model (optional) :param model_type: Filter by type of model (optional) + :param model_format: Filter by model format (e.g. "diffusers") (optional) If none of the optional filters are passed, will return all models in the database. diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py index 08956a960f..a8e777cf64 100644 --- a/invokeai/app/services/model_records/model_records_sql.py +++ b/invokeai/app/services/model_records/model_records_sql.py @@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import ( AnyModelConfig, BaseModelType, ModelConfigFactory, + ModelFormat, ModelType, ) @@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): model_name: Optional[str] = None, base_model: Optional[BaseModelType] = None, model_type: Optional[ModelType] = None, + model_format: Optional[ModelFormat] = None, ) -> List[AnyModelConfig]: """ Return models matching name, base and/or type. @@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): :param model_name: Filter by name of model (optional) :param base_model: Filter by base model (optional) :param model_type: Filter by type of model (optional) + :param model_format: Filter by model format (e.g. "diffusers") (optional) If none of the optional filters are passed, will return all models in the database. @@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase): if model_type: where_clause.append("type=?") bindings.append(model_type) + if model_format: + where_clause.append("format=?") + bindings.append(model_format) where = f"WHERE {' AND '.join(where_clause)}" if where_clause else "" with self._db.lock: self._cursor.execute( From 264ea6d94d3593bd9a67f9b30ab313e293e17319 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 14 Dec 2023 23:54:59 -0500 Subject: [PATCH 2/2] fix ruff errors --- invokeai/app/api/routers/model_records.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/app/api/routers/model_records.py b/invokeai/app/api/routers/model_records.py index 8d029db422..934a7d15b3 100644 --- a/invokeai/app/api/routers/model_records.py +++ b/invokeai/app/api/routers/model_records.py @@ -45,7 +45,9 @@ async def list_model_records( base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"), model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"), model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"), - model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"), + model_format: Optional[str] = Query( + default=None, description="Exact match on the format of the model (e.g. 'diffusers')" + ), ) -> ModelsList: """Get a list of models.""" record_store = ApiDependencies.invoker.services.model_records @@ -53,10 +55,14 @@ async def list_model_records( if base_models: for base_model in base_models: found_models.extend( - record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format) + record_store.search_by_attr( + base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format + ) ) else: - found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)) + found_models.extend( + record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format) + ) return ModelsList(models=found_models)