Create /search endpoint, update model object structure in scan model page

This commit is contained in:
Brandon Rising 2024-02-21 11:54:02 -05:00 committed by Brandon
parent 4ac5e307c4
commit c0d9990344
3 changed files with 35 additions and 4 deletions

View File

@ -32,6 +32,7 @@ from invokeai.backend.model_manager.config import (
)
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -233,6 +234,36 @@ async def list_tags() -> Set[str]:
result: Set[str] = record_store.list_tags()
return result
@model_manager_router.get(
"/search",
operation_id="search_for_models",
responses={
200: {"description": "Directory searched successfully"},
404: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[pathlib.Path],
)
async def search_for_models(
search_path: str = Query(description="Directory path to search for models", default=None),
) -> List[pathlib.Path]:
path = pathlib.Path(search_path)
if not search_path or not path.is_dir():
raise HTTPException(
status_code=404,
detail=f"The search path '{search_path}' does not exist or is not directory",
)
search = ModelSearch()
try:
models_found = list(search.search(path))
except Exception as e:
raise HTTPException(
status_code=404,
detail=f"An error occurred while searching the directory: {e}",
)
return models_found
@model_manager_router.get(
"/tags/search",

View File

@ -20,7 +20,7 @@ const ImportModelsPanel = () => {
<Button onClick={handleClickAddTab} isChecked={addModelTab === 'add'} size="sm" width="100%">
{t('modelManager.addModel')}
</Button>
<Button onClick={handleClickScanTab} isChecked={addModelTab === 'scan'} size="sm" width="100%" isDisabled>
<Button onClick={handleClickScanTab} isChecked={addModelTab === 'scan'} size="sm" width="100%">
{t('modelManager.scanForModels')}
</Button>
</ButtonGroup>

View File

@ -139,10 +139,10 @@ const modelsFilter = <T extends MainModelConfig | LoRAConfig>(
return;
}
const matchesFilter = model.model_name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesFilter = model.name.toLowerCase().includes(nameFilter.toLowerCase());
const matchesFormat = model_format === undefined || model.model_format === model_format;
const matchesType = model.model_type === model_type;
const matchesFormat = model_format === undefined || model.format === model_format;
const matchesType = model.type === model_type;
if (matchesFilter && matchesFormat && matchesType) {
filteredModels.push(model);