diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 7f9ce0b41a..5e8406c290 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -9,7 +9,6 @@ from invokeai.app.services.events.events_common import ( BatchEnqueuedEvent, FastAPIEvent, QueueClearedEvent, - QueueEventBase, QueueItemStatusChangedEvent, SessionCanceledEvent, register_events, @@ -294,10 +293,10 @@ class DefaultSessionProcessor(SessionProcessorBase): self._poll_now_event = ThreadEvent() self._cancel_event = ThreadEvent() - register_events( - events={SessionCanceledEvent, QueueClearedEvent, BatchEnqueuedEvent, QueueItemStatusChangedEvent}, - func=self._on_queue_event, - ) + register_events(events={SessionCanceledEvent}, func=self._on_session_canceled) + register_events(events={QueueClearedEvent}, func=self._on_queue_cleared) + register_events(events={BatchEnqueuedEvent}, func=self._on_batch_enqueued) + register_events(events={QueueItemStatusChangedEvent}, func=self._on_queue_item_status_changed) self._thread_semaphore = BoundedSemaphore(self._thread_limit) @@ -332,25 +331,21 @@ class DefaultSessionProcessor(SessionProcessorBase): def _poll_now(self) -> None: self._poll_now_event.set() - async def _on_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None: - _event_name, payload = event - if ( - isinstance(payload, SessionCanceledEvent) - and self._queue_item - and self._queue_item.item_id == payload.item_id - ): + async def _on_session_canceled(self, event: FastAPIEvent[SessionCanceledEvent]) -> None: + if self._queue_item and self._queue_item.item_id == event[1].item_id: self._cancel_event.set() self._poll_now() - elif ( - isinstance(payload, QueueClearedEvent) - and self._queue_item - and self._queue_item.queue_id == payload.queue_id - ): + + async def _on_queue_cleared(self, event: FastAPIEvent[QueueClearedEvent]) -> None: + if self._queue_item and self._queue_item.queue_id == event[1].queue_id: self._cancel_event.set() self._poll_now() - elif isinstance(payload, BatchEnqueuedEvent): - self._poll_now() - elif isinstance(payload, QueueItemStatusChangedEvent) and payload.status in ["completed", "failed", "canceled"]: + + async def _on_batch_enqueued(self, event: FastAPIEvent[BatchEnqueuedEvent]) -> None: + self._poll_now() + + async def _on_queue_item_status_changed(self, event: FastAPIEvent[QueueItemStatusChangedEvent]) -> None: + if self._queue_item and event[1].status in ["completed", "failed", "canceled"]: self._poll_now() def resume(self) -> SessionProcessorStatus: