mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added cancel_all
and prune
model install operations to router API
This commit is contained in:
@ -414,7 +414,7 @@ async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/{job_id}",
|
||||
"/jobs/control/{operation}/{job_id}",
|
||||
operation_id="control_install_jobs",
|
||||
responses={
|
||||
200: {"description": "The control job was updated successfully"},
|
||||
@ -426,7 +426,7 @@ async def list_install_jobs() -> List[ModelImportStatus]:
|
||||
)
|
||||
async def control_install_jobs(
|
||||
job_id: int = Path(description="Install job_id for start, pause and cancel operations"),
|
||||
operation: JobControlOperation = Body(description="The operation to perform on the job."),
|
||||
operation: JobControlOperation = Path(description="The operation to perform on the job."),
|
||||
priority_delta: Optional[int] = Body(
|
||||
description="Change in job priority for priority operations only. Negative numbers increase priority.",
|
||||
default=None,
|
||||
@ -449,6 +449,7 @@ async def control_install_jobs(
|
||||
|
||||
elif operation == JobControlOperation.CHANGE_PRIORITY:
|
||||
mgr.change_job_priority(job_id, priority_delta)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown operation {JobControlOperation.value}")
|
||||
|
||||
@ -468,3 +469,48 @@ async def control_install_jobs(
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/cancel_all",
|
||||
operation_id="cancel_all_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
)
|
||||
async def cancel_install_jobs():
|
||||
"""Cancel all pending install jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
mgr = ApiDependencies.invoker.services.model_manager
|
||||
logger.info("Cancelling all running model installation jobs.")
|
||||
mgr.cancel_all_jobs()
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@models_router.patch(
|
||||
"/jobs/prune",
|
||||
operation_id="prune_jobs",
|
||||
responses={
|
||||
204: {"description": "All jobs cancelled successfully"},
|
||||
400: {"description": "Bad request"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=ModelImportStatus,
|
||||
)
|
||||
async def prune_jobs():
|
||||
"""Prune all completed and errored jobs."""
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
try:
|
||||
mgr = ApiDependencies.invoker.services.model_manager
|
||||
mgr.prune_jobs()
|
||||
return Response(status_code=204)
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
@ -259,6 +259,16 @@ class ModelManagerServiceBase(ABC):
|
||||
"""Cancel the given install job."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""Remove completed or errored install jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
@ -474,6 +484,16 @@ class ModelManagerService(ModelManagerServiceBase):
|
||||
queue = self._loader.queue
|
||||
queue.cancel_job(queue.id_to_job(job_id))
|
||||
|
||||
def cancel_all_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self._loader.queue
|
||||
queue.cancel_all_jobs()
|
||||
|
||||
def prune_jobs(self):
|
||||
"""Cancel all active install job."""
|
||||
queue = self._loader.queue
|
||||
queue.prune_jobs()
|
||||
|
||||
def change_job_priority(self, job_id: int, delta: int):
|
||||
"""
|
||||
Change an install job's priority.
|
||||
|
@ -224,6 +224,13 @@ class DownloadQueueBase(ABC):
|
||||
"""Pause and dequeue all active jobs."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prune_jobs(self):
|
||||
"""
|
||||
Prune completed and errored queue items from the job list.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_all_jobs(self, preserve_partial: bool = False):
|
||||
"""
|
||||
|
@ -192,6 +192,21 @@ class DownloadQueue(DownloadQueueBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def prune_jobs(self):
|
||||
"""
|
||||
Prune completed and errored queue items from the job list.
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
to_delete = set()
|
||||
for job in self._jobs:
|
||||
if self._in_terminal_state(job):
|
||||
self._job.remove(job)
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
@ -322,9 +337,6 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if job.status == DownloadJobStatus.CANCELLED:
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
if self._in_terminal_state(job):
|
||||
del self._jobs[job.id]
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
def _get_metadata_and_url(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||
|
Reference in New Issue
Block a user