add install job control to web API

This commit is contained in:
Lincoln Stein
2023-09-17 15:28:37 -04:00
parent e880f4bcfb
commit f0ce559d28
6 changed files with 271 additions and 102 deletions

View File

@ -2,6 +2,7 @@
import pathlib
from enum import Enum
from typing import List, Literal, Optional, Union
from fastapi import Body, Path, Query, Response
@ -15,11 +16,10 @@ from invokeai.backend.model_manager import (
DuplicateModelException,
InvalidModelException,
ModelConfigBase,
ModelInstallJob,
SchedulerPredictionType,
UnknownModelException,
)
from invokeai.backend.model_manager.download import DownloadJobStatus
from invokeai.backend.model_manager.download import DownloadJobStatus, UnknownJobIDException
from invokeai.backend.model_manager.merge import MergeInterpolationMethod
from ..dependencies import ApiDependencies
@ -47,9 +47,18 @@ class ModelImportStatus(BaseModel):
# Status fields reported for a model install/download job by the import
# and job-control endpoints.
job_id: int  # queue-assigned identifier; used to control the job later
source: str  # where the model came from: local path, repo_id, or URL
priority: int  # queue priority; lower values are higher priority
bytes: int  # bytes downloaded so far
total_bytes: int  # total size of the download
status: DownloadJobStatus  # current state of the download/install job
class JobControlOperation(str, Enum):
    """Operations accepted by the install-job control endpoint.

    The string values are the wire-format operation names that API
    clients send; they must not be changed.
    """

    START = "Start"
    PAUSE = "Pause"
    CANCEL = "Cancel"
    CHANGE_PRIORITY = "Change Priority"
@models_router.get(
"/",
operation_id="list_models",
@ -129,20 +138,27 @@ async def import_model(
"""
Add a model using its local path, repo_id, or remote URL.
Model characteristics will be probed and configured automatically.
The return object is a ModelInstallJob job ID. The work will be
performed in the background. Listen on the event bus for a series of
`model_event` events with an `id` matching the returned job id to get
the progress, completion status, errors, and information on the
model that was installed.
"""
Models will be downloaded, probed, configured and installed in a
series of background threads. The return object has a `job_id` property
that can be used to control the download job.
Listen on the event bus for a series of `model_event` events with an `id`
matching the returned job id to get the progress, completion status, errors,
and information on the model that was installed.
"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.install_model(
location, model_attributes={"prediction_type": SchedulerPredictionType(prediction_type)}
)
return ModelImportStatus(job_id=result.id, source=result.source, priority=result.priority, status=result.status)
return ModelImportStatus(
job_id=result.id,
source=result.source,
priority=result.priority,
bytes=result.bytes,
total_bytes=result.total_bytes,
status=result.status,
)
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@ -169,19 +185,24 @@ async def import_model(
async def add_model(
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
"""Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
"""
Add a model using the configuration information appropriate for its type. Only local models can be added by path.
This call will block until the model is installed.
"""
logger = ApiDependencies.invoker.services.logger
try:
ApiDependencies.invoker.services.model_manager.add_model(
info.model_name, info.base_model, info.model_type, model_attributes=info.dict()
)
logger.info(f"Successfully added {info.model_name}")
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=info.model_name, base_model=info.base_model, model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
path = info.path
job = ApiDependencies.invoker.services.model_manager.add_model(path)
ApiDependencies.invoker.services.model_manager.wait_for_installs()
key = job.model_key
logger.info(f"Created model {key} for {path}")
# update with the provided info
info_dict = info.dict()
info_dict = {x: info_dict[x] if info_dict[x] else None for x in info_dict.keys()}
new_config = ApiDependencies.invoker.services.model_manager.update_model(key, new_config=info_dict)
return parse_obj_as(ImportModelResponse, new_config.dict())
except UnknownModelException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@ -344,3 +365,92 @@ async def merge_models(
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@models_router.get(
    "/jobs",
    operation_id="list_install_jobs",
    responses={
        # fix: the 200 description was copy-pasted from the job-control endpoint
        200: {"description": "The install jobs were retrieved successfully"},
        400: {"description": "Bad request"},
    },
    status_code=200,
    response_model=List[ModelImportStatus],
)
async def list_install_jobs() -> List[ModelImportStatus]:
    """
    List active and pending model installation jobs.

    :return: One ModelImportStatus per job currently known to the
        install queue, including download progress and state.
    """
    logger = ApiDependencies.invoker.services.logger
    mgr = ApiDependencies.invoker.services.model_manager
    try:
        return [
            ModelImportStatus(
                job_id=job.id,
                source=job.source,
                priority=job.priority,
                bytes=job.bytes,
                total_bytes=job.total_bytes,
                status=job.status,
            )
            for job in mgr.list_install_jobs()
        ]
    except Exception as e:  # surface queue errors to the client as a 400
        logger.error(str(e))
        raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
    "/jobs/{job_id}",
    operation_id="control_install_jobs",
    responses={
        200: {"description": "The control job was updated successfully"},
        400: {"description": "Bad request"},
        404: {"description": "The job could not be found"},
    },
    status_code=200,
    response_model=ModelImportStatus,
)
async def control_install_jobs(
    job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
    operation: JobControlOperation = Body(description="The operation to perform on the job."),
    priority_delta: Optional[int] = Body(
        description="Change in job priority for priority operations only. Negative numbers increase priority.",
        default=None,
    ),
) -> ModelImportStatus:
    """
    Start, pause, cancel, or change the run priority of a running model install job.

    :param job_id: ID of the install job to control.
    :param operation: One of the JobControlOperation values.
    :param priority_delta: Required for CHANGE_PRIORITY only; added to the
        job's priority (negative values raise priority).
    :raises HTTPException: 404 if the job ID is unknown, 409 on an invalid
        operation/parameter combination, 400 on other errors.
    """
    logger = ApiDependencies.invoker.services.logger
    mgr = ApiDependencies.invoker.services.model_manager
    try:
        job = mgr.id_to_job(job_id)
        if operation == JobControlOperation.START:
            mgr.start_job(job_id)
        elif operation == JobControlOperation.PAUSE:
            mgr.pause_job(job_id)
        elif operation == JobControlOperation.CANCEL:
            mgr.cancel_job(job_id)
        elif operation == JobControlOperation.CHANGE_PRIORITY:
            # fix: previously a None priority_delta was passed straight through
            if priority_delta is None:
                raise ValueError("priority_delta is required for a Change Priority operation")
            mgr.change_job_priority(job_id, priority_delta)
        else:
            # fix: original formatted the enum CLASS (JobControlOperation.value),
            # which raises AttributeError instead of reporting the bad operation
            raise ValueError(f"Unknown operation {operation.value}")
        return ModelImportStatus(
            job_id=job_id,
            source=job.source,
            priority=job.priority,
            status=job.status,
            bytes=job.bytes,
            total_bytes=job.total_bytes,
        )
    except UnknownJobIDException as e:
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
    except Exception as e:
        logger.error(str(e))
        raise HTTPException(status_code=400, detail=str(e))

View File

@ -6,7 +6,6 @@ from invokeai.app.models.image import ProgressImage
from invokeai.app.services.model_manager_service import BaseModelType, ModelInfo, ModelType, SubModelType
from invokeai.app.util.misc import get_timestamp
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.util.logging import InvokeAILogger

View File

@ -7,7 +7,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from pydantic import Field
from pydantic import Field, parse_obj_as
from pydantic.networks import AnyHttpUrl
from invokeai.app.models.exceptions import CanceledException
@ -24,6 +24,7 @@ from invokeai.backend.model_manager import (
UnknownModelException,
)
from invokeai.backend.model_manager.cache import CacheStats
from invokeai.backend.model_manager.download import DownloadJobBase
from invokeai.backend.model_manager.merge import MergeInterpolationMethod, ModelMerger
from .config import InvokeAIAppConfig
@ -32,11 +33,16 @@ if TYPE_CHECKING:
from ..invocations.baseinvocation import InvocationContext
# "forward declaration" because of circular import issues
class EventServiceBase:
pass
class ModelManagerServiceBase(ABC):
"""Responsible for managing models on disk and in memory."""
@abstractmethod
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional["EventServiceBase"] = None):
def __init__(self, config: InvokeAIAppConfig, event_bus: Optional[EventServiceBase] = None):
"""
Initialize a ModelManagerService.
@ -211,6 +217,60 @@ class ModelManagerServiceBase(ABC):
"""
pass
@abstractmethod
def list_install_jobs(self) -> List[ModelInstallJob]:
    """Return a series of active or enqueued ModelInstallJobs.

    :return: List of ModelInstallJob objects currently known to the
        install queue.
    """
    pass
@abstractmethod
def id_to_job(self, id: int) -> ModelInstallJob:
    """Return the ModelInstallJob instance corresponding to the given job ID.

    :param id: Integer ID of the install job to look up.
    """
    pass
@abstractmethod
def wait_for_installs(self) -> Dict[str, str]:
    """
    Wait for all pending installs to complete.

    This will block until all pending downloads have
    completed, been cancelled, or errored out. It will
    block indefinitely if one or more jobs are in the
    paused state.

    :return: A dict that maps the source model path, URL or
        repo_id to the ID of the installed model.
    """
    pass
@abstractmethod
def start_job(self, job_id: int):
    """Start the given install job if it is paused or idle.

    :param job_id: ID of the job to start.
    """
    pass
@abstractmethod
def pause_job(self, job_id: int):
    """Pause the given install job.

    (The previous wording "if it is paused or idle" was copy-pasted
    from start_job and did not describe pausing.)

    :param job_id: ID of the job to pause.
    """
    pass
@abstractmethod
def cancel_job(self, job_id: int):
    """Cancel the given install job.

    :param job_id: ID of the job to cancel.
    """
    pass
@abstractmethod
def change_job_priority(self, job_id: int, delta: int):
    """
    Change an install job's priority.

    :param job_id: Job to change
    :param delta: Value to increment or decrement priority.

    Lower values are higher priority. The default starting value is 10.
    Thus to make this a really high priority job:
    manager.change_job_priority(job_id, -10).
    """
    pass
@abstractmethod
def merge_models(
self,
@ -256,35 +316,6 @@ class ModelManagerServiceBase(ABC):
"""Reset model cache statistics for graph with graph_id."""
pass
@abstractmethod
def cancel_job(self, job: ModelInstallJob):
"""Cancel this job."""
pass
@abstractmethod
def pause_job(self, job: ModelInstallJob):
"""Pause this job."""
pass
@abstractmethod
def start_job(self, job: ModelInstallJob):
"""(re)start this job."""
pass
@abstractmethod
def change_priority(self, job: ModelInstallJob, delta: int):
"""
Raise or lower the priority of the job.
:param job: Job to apply change to
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make my_job a really high priority job:
manager.change_priority(my_job, -10).
"""
pass
# implementation
class ModelManagerService(ModelManagerServiceBase):
@ -390,7 +421,7 @@ class ModelManagerService(ModelManagerServiceBase):
attach to the model. When installing a URL or repo_id, some metadata
values, such as `tags` will be automagically added.
"""
self.logger.debug(f"add/update model {source}")
self.logger.debug(f"add model {source}")
variant = "fp16" if self._loader.precision == "float16" else None
return self._loader.installer.install(
source,
@ -398,6 +429,59 @@ class ModelManagerService(ModelManagerServiceBase):
variant=variant,
)
def list_install_jobs(self) -> List[ModelInstallJob]:
    """Return a series of active or enqueued ModelInstallJobs."""
    # The queue reports DownloadJobBase entries; downcast each to the
    # install-job type before returning.
    return [parse_obj_as(ModelInstallJob, job) for job in self._loader.queue.list_jobs()]
def id_to_job(self, id: int) -> ModelInstallJob:
    """Look up and return the ModelInstallJob with the given job ID."""
    queue = self._loader.queue
    return queue.id_to_job(id)
def wait_for_installs(self) -> Dict[str, str]:
    """
    Block until every pending install finishes.

    Returns once all pending downloads have completed, been cancelled,
    or errored out; blocks indefinitely while any job remains paused.

    :return: Mapping from each source (model path, URL or repo_id) to
        the ID of the model installed from it.
    """
    return self._loader.installer.wait_for_installs()
def start_job(self, job_id: int):
    """Start the given install job if it is paused or idle.

    :param job_id: ID of the job to start.
    """
    q = self._loader.queue
    q.start_job(q.id_to_job(job_id))
def pause_job(self, job_id: int):
    """Pause the given install job.

    (Docstring fixed: it previously read "if it is paused or idle",
    copy-pasted from start_job.)

    :param job_id: ID of the job to pause.
    """
    queue = self._loader.queue
    queue.pause_job(queue.id_to_job(job_id))
def cancel_job(self, job_id: int):
    """Cancel the given install job.

    :param job_id: ID of the job to cancel.
    """
    q = self._loader.queue
    q.cancel_job(q.id_to_job(job_id))
def change_job_priority(self, job_id: int, delta: int):
    """
    Change an install job's priority.

    :param job_id: Job to change
    :param delta: Value to increment or decrement priority.

    Lower values are higher priority. The default starting value is 10,
    so to make a job very high priority:
    manager.change_job_priority(job_id, -10).
    """
    q = self._loader.queue
    job = q.id_to_job(job_id)
    q.change_priority(job, delta)
def update_model(
self,
key: str,
@ -415,7 +499,8 @@ class ModelManagerService(ModelManagerServiceBase):
"""
self.logger.debug(f"update model {key}")
new_info = self._loader.store.update_model(key, new_config)
return self._loader.installer.sync_model_path(new_info.key)
self._loader.installer.sync_model_path(new_info.key)
return new_info
def del_model(
self,
@ -570,28 +655,3 @@ class ModelManagerService(ModelManagerServiceBase):
:param new_name: New name for the model
"""
return self.update_model(key, {"name": new_name})
def cancel_job(self, job: ModelInstallJob):
"""Cancel this job."""
self._loader.queue.cancel_job(job)
def pause_job(self, job: ModelInstallJob):
"""Pause this job."""
self._loader.queue.pause_job(job)
def start_job(self, job: ModelInstallJob):
"""(re)start this job."""
self._loader.queue.start_job(job)
def change_priority(self, job: ModelInstallJob, delta: int):
"""
Raise or lower the priority of the job.
:param job: Job to apply change to
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
Thus to make my_job a really high priority job:
manager.change_priority(my_job, -10).
"""
self._loader.queue.change_priority(job, delta)

View File

@ -17,7 +17,7 @@ from .install import ModelInstall, ModelInstallJob # noqa F401
from .loader import ModelInfo, ModelLoad # noqa F401
from .lora import ModelPatcher, ONNXModelPatcher
from .models import OPENAPI_MODEL_CONFIGS, read_checkpoint_meta # noqa F401
from .probe import InvalidModelException, ModelProbeInfo # noqa F401
from .probe import InvalidModelException, ModelProbe, ModelProbeInfo # noqa F401
from .search import ModelSearch # noqa F401
from .storage import ( # noqa F401
DuplicateModelException,

View File

@ -383,22 +383,26 @@ class ModelInstall(ModelInstallBase):
model_format=info.format,
)
# add 'main' specific fields
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
)
except KeyError as exc:
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
if info.model_type == ModelType.Main:
registration_data.update(variant=info.variant_type)
if info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
registration_data.update(
config=Path(self._config.legacy_conf_dir, config_file).as_posix(),
)
except KeyError as exc:
raise InvalidModelException(
"Configuration file for this checkpoint could not be determined"
) from exc
self._store.add_model(key, registration_data)
return key

View File

@ -456,12 +456,8 @@ class PipelineFolderProbe(FolderProbeBase):
"""Return the SchedulerPredictionType of a diffusers-style sd-2 model."""
with open(self.model / "scheduler" / "scheduler_config.json", "r") as file:
scheduler_conf = json.load(file)
if scheduler_conf["prediction_type"] == "v_prediction":
return SchedulerPredictionType.VPrediction
elif scheduler_conf["prediction_type"] == "epsilon":
return SchedulerPredictionType.Epsilon
else:
return None
prediction_type = scheduler_conf.get("prediction_type", "epsilon")
return SchedulerPredictionType(prediction_type)
def get_variant_type(self) -> ModelVariantType:
"""Return the ModelVariantType for diffusers-style main models."""