Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
import_model API now working
@@ -15,9 +15,11 @@ from invokeai.backend.model_manager import (
     DuplicateModelException,
     InvalidModelException,
     ModelConfigBase,
+    ModelInstallJob,
     SchedulerPredictionType,
     UnknownModelException,
 )
+from invokeai.backend.model_manager.download import DownloadJobStatus
 from invokeai.backend.model_manager.merge import MergeInterpolationMethod
 
 from ..dependencies import ApiDependencies
@@ -39,6 +41,15 @@ class ModelsList(BaseModel):
     models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
 
 
+class ModelImportStatus(BaseModel):
+    """Return information about a background installation job."""
+
+    job_id: int
+    source: str
+    priority: int
+    status: DownloadJobStatus
+
+
 @models_router.get(
     "/",
     operation_id="list_models",
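Note: the new ModelImportStatus model above is what import_model now returns instead of the full model configuration. Below is a minimal standalone sketch of that response shape; it uses a stand-in enum because the real DownloadJobStatus lives in invokeai.backend.model_manager.download and its exact member names are not shown in this diff.

# Standalone sketch of the ModelImportStatus response shape.
# DownloadJobStatus members below are placeholders; the real enum is defined
# in invokeai.backend.model_manager.download and may differ.
from enum import Enum

from pydantic import BaseModel


class DownloadJobStatus(str, Enum):
    """Stand-in for the backend enum (member names assumed)."""

    waiting = "waiting"
    running = "running"
    completed = "completed"
    error = "error"


class ModelImportStatus(BaseModel):
    """Return information about a background installation job."""

    job_id: int
    source: str
    priority: int
    status: DownloadJobStatus


# Roughly what an API client would receive after enqueuing an import:
example = ModelImportStatus(
    job_id=3,
    source="runwayml/stable-diffusion-v1-5",
    priority=10,
    status=DownloadJobStatus.running,
)
print(example.json())  # pydantic v1 style, matching parse_obj_as used elsewhere in this file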
@@ -107,36 +118,31 @@ async def update_model(
         409: {"description": "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
-    response_model=ImportModelResponse,
+    response_model=ModelImportStatus,
 )
 async def import_model(
     location: str = Body(description="A model path, repo_id or URL to import"),
     prediction_type: Optional[Literal["v_prediction", "epsilon", "sample"]] = Body(
         description="Prediction type for SDv2 checkpoint files", default="v_prediction"
     ),
-) -> ImportModelResponse:
-    """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
+) -> ModelImportStatus:
+    """
+    Add a model using its local path, repo_id, or remote URL.
+
+    Model characteristics will be probed and configured automatically.
+    The return object is a ModelInstallJob job ID. The work will be
+    performed in the background. Listen on the event bus for a series of
+    `model_event` events with an `id` matching the returned job id to get
+    the progress, completion status, errors, and information on the
+    model that was installed.
+    """
-    items_to_import = {location}
-    prediction_types = {x.value: x for x in SchedulerPredictionType}
     logger = ApiDependencies.invoker.services.logger
 
     try:
-        installed_models = ApiDependencies.invoker.services.model_manager.heuristic_import(
-            items_to_import=items_to_import, prediction_type_helper=lambda x: prediction_types.get(prediction_type)
+        result = ApiDependencies.invoker.services.model_manager.install_model(
+            location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
         )
-        info = installed_models.get(location)
-
-        if not info:
-            logger.error("Import failed")
-            raise HTTPException(status_code=415)
-
-        logger.info(f"Successfully imported {location}, got {info}")
-        model_raw = ApiDependencies.invoker.services.model_manager.list_model(
-            model_name=info.name, base_model=info.base_model, model_type=info.model_type
-        )
-        return parse_obj_as(ImportModelResponse, model_raw)
-
+        return ModelImportStatus(job_id=result.id, source=result.source, priority=result.priority, status=result.status)
     except UnknownModelException as e:
         logger.error(str(e))
         raise HTTPException(status_code=404, detail=str(e))
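Since the endpoint now only enqueues an installation job, a client calls it, keeps the returned job_id, and then watches the event bus for matching model_event events. A rough request-side sketch follows, assuming a local server at http://localhost:9090 and an import route mounted at /api/v1/models/import; neither is shown in this diff, so treat both as placeholders.

# Hypothetical client flow for the new asynchronous import_model endpoint.
# The base URL and route path are assumptions, not taken from this diff.
import requests

BASE_URL = "http://localhost:9090"        # assumed server address
IMPORT_ROUTE = "/api/v1/models/import"    # assumed mount point of import_model

resp = requests.post(
    BASE_URL + IMPORT_ROUTE,
    # Two Body() parameters, so FastAPI expects both keys in one JSON object.
    json={"location": "runwayml/stable-diffusion-v1-5", "prediction_type": "epsilon"},
    timeout=30,
)
resp.raise_for_status()  # 201 on success

job = resp.json()  # ModelImportStatus: job_id, source, priority, status
print(f"queued job {job['job_id']} ({job['source']}), initial status: {job['status']}")

# Progress is not reported in this response. Per the docstring, subscribe to
# the event bus and filter `model_event` events whose `id` equals job["job_id"]
# to follow download/installation progress and the final result.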
@@ -221,7 +221,6 @@ export const modelsApi = api.injectEndpoints({
       const tags: ApiFullTagDescription[] = [
         { type: 'MainModel', id: LIST_TAG },
       ];
 
       if (result) {
         tags.push(
           ...result.ids.map((id) => ({