feat(mm): add logic to scan_folder route to check if a model is already installed

This was done in the frontend before but it's something the backend should handle.

The logic compares the found model paths to the path and source of all installed models. It excludes core models.
This commit is contained in:
psychedelicious 2024-02-24 18:46:54 +11:00
parent 98d60e7db5
commit 531d6f40f4

View File

@ -9,7 +9,7 @@ from typing import Any, Dict, List, Optional, Set
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -233,6 +233,11 @@ async def list_tags() -> Set[str]:
return result
class FoundModel(BaseModel):
path: str = Field(description="Path to the model")
is_installed: bool = Field(description="Whether or not the model is already installed")
@model_manager_router.get(
"/scan_folder",
operation_id="scan_for_models",
@ -241,11 +246,11 @@ async def list_tags() -> Set[str]:
400: {"description": "Invalid directory path"},
},
status_code=200,
response_model=List[pathlib.Path],
response_model=List[FoundModel],
)
async def scan_for_models(
scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[pathlib.Path]:
) -> List[FoundModel]:
path = pathlib.Path(scan_path)
if not scan_path or not path.is_dir():
raise HTTPException(
@ -255,13 +260,46 @@ async def scan_for_models(
search = ModelSearch()
try:
models_found = list(search.search(path))
found_model_paths = search.search(path)
models_path = ApiDependencies.invoker.services.configuration.models_path
# If the search path includes the main models directory, we need to exclude core models from the list.
# TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
# without needing to crawl the filesystem.
core_models_path = pathlib.Path(models_path, "core").resolve()
non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]
installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()
resolved_installed_model_paths: list[str] = []
installed_model_sources: list[str] = []
# This call lists all installed models.
for model in installed_models:
path = pathlib.Path(model.path)
# If the model has a source, we need to add it to the list of installed sources.
if model.source:
installed_model_sources.append(model.source)
# If the path is not absolute, that means it is in the app models directory, and we need to join it with
# the models path before resolving.
if not path.is_absolute():
resolved_installed_model_paths.append(str(pathlib.Path(models_path, path).resolve()))
continue
resolved_installed_model_paths.append(str(path.resolve()))
scan_results: list[FoundModel] = []
# Check if the model is installed by comparing the resolved paths, appending to the scan result.
for p in non_core_model_paths:
path = str(p)
is_installed = path in resolved_installed_model_paths or path in installed_model_sources
found_model = FoundModel(path=path, is_installed=is_installed)
scan_results.append(found_model)
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"An error occurred while searching the directory: {e}",
)
return models_found
return scan_results
@model_manager_router.get(