mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add ability to filter model listings by format
This commit is contained in:
parent
340957f920
commit
569ae7c482
@ -45,6 +45,7 @@ async def list_model_records(
|
||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||
model_format: Optional[str] = Query(default=None, description="Exact match on the format of the model (e.g. 'diffusers')"),
|
||||
) -> ModelsList:
|
||||
"""Get a list of models."""
|
||||
record_store = ApiDependencies.invoker.services.model_records
|
||||
@ -52,10 +53,10 @@ async def list_model_records(
|
||||
if base_models:
|
||||
for base_model in base_models:
|
||||
found_models.extend(
|
||||
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
|
||||
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format)
|
||||
)
|
||||
else:
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
|
||||
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format))
|
||||
return ModelsList(models=found_models)
|
||||
|
||||
|
||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||
|
||||
|
||||
class DuplicateModelException(Exception):
|
||||
@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC):
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC):
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
|
@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
)
|
||||
|
||||
@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
model_name: Optional[str] = None,
|
||||
base_model: Optional[BaseModelType] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
model_format: Optional[ModelFormat] = None,
|
||||
) -> List[AnyModelConfig]:
|
||||
"""
|
||||
Return models matching name, base and/or type.
|
||||
@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
:param model_name: Filter by name of model (optional)
|
||||
:param base_model: Filter by base model (optional)
|
||||
:param model_type: Filter by type of model (optional)
|
||||
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||
|
||||
If none of the optional filters are passed, will return all
|
||||
models in the database.
|
||||
@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
||||
if model_type:
|
||||
where_clause.append("type=?")
|
||||
bindings.append(model_type)
|
||||
if model_format:
|
||||
where_clause.append("format=?")
|
||||
bindings.append(model_format)
|
||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||
with self._db.lock:
|
||||
self._cursor.execute(
|
||||
|
Loading…
Reference in New Issue
Block a user