import_model API now working

This commit is contained in:
Lincoln Stein
2023-09-16 22:17:39 -04:00
parent c029534243
commit 539776a15a
2 changed files with 26 additions and 21 deletions

View File

@ -15,9 +15,11 @@ from invokeai.backend.model_manager import (
DuplicateModelException,
InvalidModelException,
ModelConfigBase,
ModelInstallJob,
SchedulerPredictionType,
UnknownModelException,
)
from invokeai.backend.model_manager.download import DownloadJobStatus
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from ..dependencies import ApiDependencies
@ -39,6 +41,15 @@ class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
class ModelImportStatus(BaseModel):
"""Return information about a background installation job."""
job_id: int
source: str
priority: int
status: DownloadJobStatus
@models_router.get(
"/",
operation_id="list_models",
@ -107,36 +118,31 @@ async def update_model(
409: {"description": "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse,
response_model=ModelImportStatus,
)
async def import_model(
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
description="Prediction type for SDv2 checkpoint files", default="v_prediction"
),
) -> ImportModelResponse:
"""Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
) -> ModelImportStatus:
"""
Add a model using its local path, repo_id, or remote URL.
Model characteristics will be probed and configured automatically.
The return object is a ModelInstallJob job ID. The work will be
performed in the background. Listen on the event bus for a series of
`model_event` events with an `id` matching the returned job id to get
the progress, completion status, errors, and information on the
model that was installed.
"""
items_to_import = {location}
prediction_types = {x.value: x for x in SchedulerPredictionType}
logger = ApiDependencies.invoker.services.logger
try:
installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
result = ApiDependencies.invoker.services.model_manager.install_model(
location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
)
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=415)
logger.info(f"Successfully imported {location}, got {info}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.name, base_model=info.base_model, model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
return ModelImportStatus(job_id=result.id, source=result.source, priority=result.priority, status=result.status)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))

View File

@ -221,7 +221,6 @@ export const modelsApi = api.injectEndpoints({
const tags: ApiFullTagDescription[] = [
{ type: 'MainModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({