tidy(processor): use separate handlers for each event type

Just a bit clearer without needing `isinstance` checks.
This commit is contained in:
psychedelicious 2024-05-25 20:23:14 +10:00
parent 39415428de
commit 25d1d2b591

View File

@ -9,7 +9,6 @@ from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent, BatchEnqueuedEvent,
FastAPIEvent, FastAPIEvent,
QueueClearedEvent, QueueClearedEvent,
QueueEventBase,
QueueItemStatusChangedEvent, QueueItemStatusChangedEvent,
SessionCanceledEvent, SessionCanceledEvent,
register_events, register_events,
@ -294,10 +293,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._poll_now_event = ThreadEvent() self._poll_now_event = ThreadEvent()
self._cancel_event = ThreadEvent() self._cancel_event = ThreadEvent()
register_events( register_events(events={SessionCanceledEvent}, func=self._on_session_canceled)
events={SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent, QueueItemStatusChangedEvent}, register_events(events={QueueClearedEvent}, func=self._on_queue_cleared)
func=self._on_queue_event, register_events(events={BatchEnqueuedEvent}, func=self._on_batch_enqueued)
) register_events(events={QueueItemStatusChangedEvent}, func=self._on_queue_item_status_changed)
self._thread_semaphore = BoundedSemaphore(self._thread_limit) self._thread_semaphore = BoundedSemaphore(self._thread_limit)
@ -332,25 +331,21 @@ class DefaultSessionProcessor(SessionProcessorBase):
def _poll_now(self) -> None: def _poll_now(self) -> None:
self._poll_now_event.set() self._poll_now_event.set()
async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None: async def _on_session_canceled(self, event: FastAPIEvent[SessionCanceledEvent]) -> None:
_event_name, payload = event if self._queue_item and self._queue_item.item_id == event[1].item_id:
if (
isinstance(payload, SessionCanceledEvent)
and self._queue_item
and self._queue_item.item_id == payload.item_id
):
self._cancel_event.set() self._cancel_event.set()
self._poll_now() self._poll_now()
elif (
isinstance(payload, QueueClearedEvent) async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None:
and self._queue_item if self._queue_item and self._queue_item.queue_id == event[1].queue_id:
and self._queue_item.queue_id == payload.queue_id
):
self._cancel_event.set() self._cancel_event.set()
self._poll_now() self._poll_now()
elif isinstance(payload, BatchEnqueuedEvent):
async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None:
self._poll_now() self._poll_now()
elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ["completed", "failed", "canceled"]:
async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None:
if self._queue_item and event[1].status in ["completed", "failed", "canceled"]:
self._poll_now() self._poll_now()
def resume(self) -> SessionProcessorStatus: def resume(self) -> SessionProcessorStatus: