mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[feature] add ability to filter model listings by format (#5286)
## What type of PR is this? (check all applicable) - [ ] Refactor - [X] Feature - [ ] Bug Fix - [ ] Optimization - [ ] Documentation Update - [ ] Community Node Submission ## Have you discussed this change with the InvokeAI team? - [X] Yes - [ ] No, because: ## Have you updated all relevant documentation? - [X] Yes - [ ] No ## Description This minor change adds the ability to filter the model lists returned by V2 of the model manager using the model file format (e.g. "checkpoint"). Just thought this would be a useful feature. ## Related Tickets & Documents <!-- For pull requests that relate or close an issue, please include them below. For example having the text: "closes #1234" would connect the current pull request to issue 1234. And when we merge the pull request, Github will automatically close the issue. --> - Related Issue # - Closes # ## QA Instructions, Screenshots, Recordings <!-- Please provide steps on how to test changes, any hardware or software specifications as well as any other pertinent information. --> ## Merge Plan This can be merged when approved without any adverse effects. <!-- A merge plan describes how this PR should be handled after it is approved. Example merge plans: - "This PR can be merged when approved" - "This must be squash-merged when approved" - "DO NOT MERGE - I will rebase and tidy commits before merging" - "#dev-chat on discord needs to be advised of this change when it is merged" A merge plan is particularly important for large PRs or PRs that touch the database in any way. --> ## Added/updated tests? - [ ] Yes - [X] No : minor feature - tested informally using the router API ## [optional] Are there any post deployment tasks we need to perform?
This commit is contained in:
commit
454f01e0c1
@ -45,6 +45,9 @@ async def list_model_records(
|
|||||||
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
|
||||||
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
|
||||||
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
|
||||||
|
model_format: Optional[str] = Query(
|
||||||
|
default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
|
||||||
|
),
|
||||||
) -> ModelsList:
|
) -> ModelsList:
|
||||||
"""Get a list of models."""
|
"""Get a list of models."""
|
||||||
record_store = ApiDependencies.invoker.services.model_records
|
record_store = ApiDependencies.invoker.services.model_records
|
||||||
@ -52,10 +55,14 @@ async def list_model_records(
|
|||||||
if base_models:
|
if base_models:
|
||||||
for base_model in base_models:
|
for base_model in base_models:
|
||||||
found_models.extend(
|
found_models.extend(
|
||||||
record_store.search_by_attr(base_model=base_model, model_type=model_type, model_name=model_name)
|
record_store.search_by_attr(
|
||||||
|
base_model=base_model, model_type=model_type, model_name=model_name, model_format=model_format
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
found_models.extend(record_store.search_by_attr(model_type=model_type, model_name=model_name))
|
found_models.extend(
|
||||||
|
record_store.search_by_attr(model_type=model_type, model_name=model_name, model_format=model_format)
|
||||||
|
)
|
||||||
return ModelsList(models=found_models)
|
return ModelsList(models=found_models)
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType
|
||||||
|
|
||||||
|
|
||||||
class DuplicateModelException(Exception):
|
class DuplicateModelException(Exception):
|
||||||
@ -106,6 +106,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
|
model_format: Optional[ModelFormat] = None,
|
||||||
) -> List[AnyModelConfig]:
|
) -> List[AnyModelConfig]:
|
||||||
"""
|
"""
|
||||||
Return models matching name, base and/or type.
|
Return models matching name, base and/or type.
|
||||||
@ -113,6 +114,7 @@ class ModelRecordServiceBase(ABC):
|
|||||||
:param model_name: Filter by name of model (optional)
|
:param model_name: Filter by name of model (optional)
|
||||||
:param base_model: Filter by base model (optional)
|
:param base_model: Filter by base model (optional)
|
||||||
:param model_type: Filter by type of model (optional)
|
:param model_type: Filter by type of model (optional)
|
||||||
|
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||||
|
|
||||||
If none of the optional filters are passed, will return all
|
If none of the optional filters are passed, will return all
|
||||||
models in the database.
|
models in the database.
|
||||||
|
@ -49,6 +49,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,6 +226,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
base_model: Optional[BaseModelType] = None,
|
base_model: Optional[BaseModelType] = None,
|
||||||
model_type: Optional[ModelType] = None,
|
model_type: Optional[ModelType] = None,
|
||||||
|
model_format: Optional[ModelFormat] = None,
|
||||||
) -> List[AnyModelConfig]:
|
) -> List[AnyModelConfig]:
|
||||||
"""
|
"""
|
||||||
Return models matching name, base and/or type.
|
Return models matching name, base and/or type.
|
||||||
@ -232,6 +234,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
:param model_name: Filter by name of model (optional)
|
:param model_name: Filter by name of model (optional)
|
||||||
:param base_model: Filter by base model (optional)
|
:param base_model: Filter by base model (optional)
|
||||||
:param model_type: Filter by type of model (optional)
|
:param model_type: Filter by type of model (optional)
|
||||||
|
:param model_format: Filter by model format (e.g. "diffusers") (optional)
|
||||||
|
|
||||||
If none of the optional filters are passed, will return all
|
If none of the optional filters are passed, will return all
|
||||||
models in the database.
|
models in the database.
|
||||||
@ -248,6 +251,9 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
|
|||||||
if model_type:
|
if model_type:
|
||||||
where_clause.append("type=?")
|
where_clause.append("type=?")
|
||||||
bindings.append(model_type)
|
bindings.append(model_type)
|
||||||
|
if model_format:
|
||||||
|
where_clause.append("format=?")
|
||||||
|
bindings.append(model_format)
|
||||||
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
|
||||||
with self._db.lock:
|
with self._db.lock:
|
||||||
self._cursor.execute(
|
self._cursor.execute(
|
||||||
|
Loading…
Reference in New Issue
Block a user