mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
1 Commits
improve-co
...
feat/queue
Author | SHA1 | Date | |
---|---|---|---|
5e6b5c8fd6 |
@ -93,6 +93,18 @@ async def Pause(
|
|||||||
return ApiDependencies.invoker.services.session_processor.pause()
|
return ApiDependencies.invoker.services.session_processor.pause()
|
||||||
|
|
||||||
|
|
||||||
|
@session_queue_router.put(
|
||||||
|
"/{queue_id}/processor/take_one",
|
||||||
|
operation_id="take_one",
|
||||||
|
responses={200: {"model": SessionProcessorStatus}},
|
||||||
|
)
|
||||||
|
async def take_one(
|
||||||
|
queue_id: str = Path(description="The queue id to perform this operation on"),
|
||||||
|
) -> SessionProcessorStatus:
|
||||||
|
"""Executes the next-in-line queue item, pausing the processor afterwards. Has no effect if the queue is resumed."""
|
||||||
|
return ApiDependencies.invoker.services.session_processor.take_one()
|
||||||
|
|
||||||
|
|
||||||
@session_queue_router.put(
|
@session_queue_router.put(
|
||||||
"/{queue_id}/cancel_by_batch_ids",
|
"/{queue_id}/cancel_by_batch_ids",
|
||||||
operation_id="cancel_by_batch_ids",
|
operation_id="cancel_by_batch_ids",
|
||||||
|
@ -22,6 +22,11 @@ class SessionProcessorBase(ABC):
|
|||||||
"""Pauses the session processor"""
|
"""Pauses the session processor"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def take_one(self) -> SessionProcessorStatus:
|
||||||
|
"""Takes one session from the queue and executes it"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_status(self) -> SessionProcessorStatus:
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
"""Gets the status of the session processor"""
|
"""Gets the status of the session processor"""
|
||||||
|
@ -25,6 +25,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self.__resume_event = ThreadEvent()
|
self.__resume_event = ThreadEvent()
|
||||||
self.__stop_event = ThreadEvent()
|
self.__stop_event = ThreadEvent()
|
||||||
self.__poll_now_event = ThreadEvent()
|
self.__poll_now_event = ThreadEvent()
|
||||||
|
self.__take_one_event = ThreadEvent()
|
||||||
|
|
||||||
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_queue_event)
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
"stop_event": self.__stop_event,
|
"stop_event": self.__stop_event,
|
||||||
"poll_now_event": self.__poll_now_event,
|
"poll_now_event": self.__poll_now_event,
|
||||||
"resume_event": self.__resume_event,
|
"resume_event": self.__resume_event,
|
||||||
|
"take_one_event": self.__take_one_event,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
self.__thread.start()
|
self.__thread.start()
|
||||||
@ -81,6 +83,13 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self.__resume_event.clear()
|
self.__resume_event.clear()
|
||||||
return self.get_status()
|
return self.get_status()
|
||||||
|
|
||||||
|
def take_one(self) -> SessionProcessorStatus:
|
||||||
|
if self.__queue_item is None and not self.__resume_event.is_set():
|
||||||
|
self.__resume_event.set()
|
||||||
|
self.__take_one_event.set()
|
||||||
|
self._poll_now()
|
||||||
|
return self.get_status()
|
||||||
|
|
||||||
def get_status(self) -> SessionProcessorStatus:
|
def get_status(self) -> SessionProcessorStatus:
|
||||||
return SessionProcessorStatus(
|
return SessionProcessorStatus(
|
||||||
is_started=self.__resume_event.is_set(),
|
is_started=self.__resume_event.is_set(),
|
||||||
@ -92,9 +101,11 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
stop_event: ThreadEvent,
|
stop_event: ThreadEvent,
|
||||||
poll_now_event: ThreadEvent,
|
poll_now_event: ThreadEvent,
|
||||||
resume_event: ThreadEvent,
|
resume_event: ThreadEvent,
|
||||||
|
take_one_event: ThreadEvent,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
stop_event.clear()
|
stop_event.clear()
|
||||||
|
take_one_event.clear()
|
||||||
resume_event.set()
|
resume_event.set()
|
||||||
self.__threadLimit.acquire()
|
self.__threadLimit.acquire()
|
||||||
queue_item: Optional[SessionQueueItem] = None
|
queue_item: Optional[SessionQueueItem] = None
|
||||||
@ -118,6 +129,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
)
|
)
|
||||||
queue_item = None
|
queue_item = None
|
||||||
|
|
||||||
|
if take_one_event.is_set():
|
||||||
|
resume_event.clear()
|
||||||
|
take_one_event.clear()
|
||||||
|
|
||||||
if queue_item is None:
|
if queue_item is None:
|
||||||
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
self.__invoker.services.logger.debug("Waiting for next polling interval or event")
|
||||||
poll_now_event.wait(POLLING_INTERVAL)
|
poll_now_event.wait(POLLING_INTERVAL)
|
||||||
|
Reference in New Issue
Block a user