InvokeAI/invokeai/app/services/events/events_fastapievents.py
psychedelicious f6b8970bd1 fix(app): create reference to events task to prevent accidental GC
This wasn't a problem, but it's advised in the official docs so I've done it.
2024-08-12 07:49:58 +10:00

45 lines
1.8 KiB
Python

import asyncio
import threading
from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import EventBase
class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id
self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event()
self._loop = loop
# We need to store a reference to the task so it doesn't get GC'd
# See: https://docs.python.org/3/library/asyncio-task.html#creating-tasks
self._background_tasks: set[asyncio.Task[None]] = set()
task = self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.remove)
super().__init__()
def stop(self, *args, **kwargs):
self._stop_event.set()
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None:
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set():
try:
event = await self._queue.get()
if not event: # Probably stopping
continue
# Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except asyncio.CancelledError as e:
raise e # Raise a proper error