added cancel_all and prune model install operations to router API

This commit is contained in:
Lincoln Stein
2023-09-25 17:34:59 -04:00
parent ac4634000a
commit effced8560
4 changed files with 90 additions and 5 deletions

View File

@ -414,7 +414,7 @@ async def list_install_jobs() -> List[ModelImportStatus]:
@models_router.patch(
"/jobs/{job_id}",
"/jobs/control/{operation}/{job_id}",
operation_id="control_install_jobs",
responses={
200: {"description": "The control job was updated successfully"},
@ -426,7 +426,7 @@ async def list_install_jobs() -> List[ModelImportStatus]:
)
async def control_install_jobs(
job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
operation: JobControlOperation = Body(description="The operation to perform on the job."),
operation: JobControlOperation = Path(description="The operation to perform on the job."),
priority_delta: Optional[int] = Body(
description="Change in job priority for priority operations only. Negative numbers increase priority.",
default=None,
@ -449,6 +449,7 @@ async def control_install_jobs(
elif operation == JobControlOperation.CHANGE_PRIORITY:
mgr.change_job_priority(job_id, priority_delta)
else:
raise ValueError(f"Unknown operation {JobControlOperation.value}")
@ -468,3 +469,48 @@ async def control_install_jobs(
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
"/jobs/cancel_all",
operation_id="cancel_all_jobs",
responses={
204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"},
},
status_code=200,
response_model=ModelImportStatus,
)
async def cancel_install_jobs():
"""Cancel all pending install jobs."""
logger = ApiDependencies.invoker.services.logger
try:
mgr = ApiDependencies.invoker.services.model_manager
logger.info("Cancelling all running model installation jobs.")
mgr.cancel_all_jobs()
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
@models_router.patch(
"/jobs/prune",
operation_id="prune_jobs",
responses={
204: {"description": "All jobs cancelled successfully"},
400: {"description": "Bad request"},
},
status_code=200,
response_model=ModelImportStatus,
)
async def prune_jobs():
"""Prune all completed and errored jobs."""
logger = ApiDependencies.invoker.services.logger
try:
mgr = ApiDependencies.invoker.services.model_manager
mgr.prune_jobs()
return Response(status_code=204)
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))

View File

@ -259,6 +259,16 @@ class ModelManagerServiceBase(ABC):
"""Cancel the given install job."""
pass
@abstractmethod
def cancel_all_jobs(self):
"""Cancel all active jobs."""
pass
@abstractmethod
def prune_jobs(self):
"""Remove completed or errored install jobs."""
pass
@abstractmethod
def change_job_priority(self, job_id: int, delta: int):
"""
@ -474,6 +484,16 @@ class ModelManagerService(ModelManagerServiceBase):
queue = self._loader.queue
queue.cancel_job(queue.id_to_job(job_id))
def cancel_all_jobs(self):
"""Cancel all active install job."""
queue = self._loader.queue
queue.cancel_all_jobs()
def prune_jobs(self):
"""Cancel all active install job."""
queue = self._loader.queue
queue.prune_jobs()
def change_job_priority(self, job_id: int, delta: int):
"""
Change an install job's priority.

View File

@ -224,6 +224,13 @@ class DownloadQueueBase(ABC):
"""Pause and dequeue all active jobs."""
pass
@abstractmethod
def prune_jobs(self):
"""
Prune completed and errored queue items from the job list.
"""
pass
@abstractmethod
def cancel_all_jobs(self, preserve_partial: bool = False):
"""

View File

@ -192,6 +192,21 @@ class DownloadQueue(DownloadQueueBase):
finally:
self._lock.release()
def prune_jobs(self):
"""
Prune completed and errored queue items from the job list.
"""
try:
self._lock.acquire()
to_delete = set()
for job in self._jobs:
if self._in_terminal_state(job):
self._job.remove(job)
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
Cancel the indicated job.
@ -322,9 +337,6 @@ class DownloadQueue(DownloadQueueBase):
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
if self._in_terminal_state(job):
del self._jobs[job.id]
self._queue.task_done()
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl: