fix(app): use asyncio queue and existing event loop for events

Around the time we (I) implemented pydantic events, I noticed a short pause between progress images every 4 or 5 steps when generating with SDXL. It didn't happen with SD1.5, but I did notice that with SD1.5, we'd get 4 or 5 progress events simultaneously. I'd expect one event every ~25ms, matching my it/s with SD1.5. Mysterious!

Digging in, I found an issue is related to our use of a synchronous queue for events. When the event queue is empty, we must call `asyncio.sleep` before checking again. We were sleeping for 100ms.

Said another way, every time we clear the event queue, we have to wait 100ms before another event can be dispatched, even if it is put on the queue immediately after we start waiting. In practice, this means our events get buffered into batches, dispatched once every 100ms.

This explains why I was getting batches of 4 or 5 SD1.5 progress events at once, but not the intermittent SDXL delay.

But this 100ms wait has another effect when the events are put on the queue in intervals that don't perfectly line up with the 100ms wait. This is most noticeable when the time between events is >100ms, and can add up to 100ms delay before the event is dispatched.

For example, say the queue is empty and we start a 100ms wait. Then, immediately after - like 0.01ms later - we push an event on to the queue. We still need to wait another 99.9ms before that event will be dispatched. That's the SDXL delay.

The easy fix is to reduce the sleep to something like 0.01 seconds, but this feels kinda dirty. Can't we just wait on the queue and dispatch every event immediately? Not with the normal synchronous queue - but we can with `asyncio.Queue`.

I switched the events queue to use `asyncio.Queue` (as seen in this commit), which lets us asynchronous wait on the queue in a loop.

Unfortunately, I ran into another issue - events now felt like their timing was inconsistent, but in a different way than with the 100ms sleep. The time between pushing events on the queue and dispatching them was not consistently ~0ms as I'd expect - it was highly variable from ~0ms up to ~100ms.

This is resolved by passing the asyncio loop directly into the events service and using its methods to create the task and interact with the queue. I don't fully understand why this resolved the issue, because either way we are interacting with the same event loop (as shown by `asyncio.get_running_loop()`). I suppose there's some scheduling magic happening.
This commit is contained in:
psychedelicious 2024-08-04 16:25:50 +10:00
parent 8ecf72838d
commit 29325a7214
3 changed files with 19 additions and 21 deletions

View File

@ -1,5 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio
from logging import Logger from logging import Logger
import torch import torch
@ -63,7 +64,12 @@ class ApiDependencies:
invoker: Invoker invoker: Invoker
@staticmethod @staticmethod
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None: def initialize(
config: InvokeAIAppConfig,
event_handler_id: int,
loop: asyncio.AbstractEventLoop,
logger: Logger = logger,
) -> None:
logger.info(f"InvokeAI version {__version__}") logger.info(f"InvokeAI version {__version__}")
logger.info(f"Root directory = {str(config.root_path)}") logger.info(f"Root directory = {str(config.root_path)}")
@ -84,7 +90,7 @@ class ApiDependencies:
board_images = BoardImagesService() board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db) board_records = SqliteBoardRecordStorage(db=db)
boards = BoardService() boards = BoardService()
events = FastAPIEventService(event_handler_id) events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService() bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db) image_records = SqliteImageRecordStorage(db=db)
images = ImageService() images = ImageService()

View File

@ -55,11 +55,13 @@ mimetypes.add_type("text/css", ".css")
torch_device_name = TorchDevice.get_torch_device_name() torch_device_name = TorchDevice.get_torch_device_name()
logger.info(f"Using torch device: {torch_device_name}") logger.info(f"Using torch device: {torch_device_name}")
loop = asyncio.new_event_loop()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
# Add startup event to load dependencies # Add startup event to load dependencies
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger) ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
yield yield
# Shut down threads # Shut down threads
ApiDependencies.shutdown() ApiDependencies.shutdown()
@ -184,8 +186,6 @@ def invoke_api() -> None:
check_cudnn(logger) check_cudnn(logger)
# Start our own event loop for eventing usage
loop = asyncio.new_event_loop()
config = uvicorn.Config( config = uvicorn.Config(
app=app, app=app,
host=app_config.host, host=app_config.host,

View File

@ -1,46 +1,38 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import asyncio import asyncio
import threading import threading
from queue import Empty, Queue
from fastapi_events.dispatcher import dispatch from fastapi_events.dispatcher import dispatch
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.events.events_common import ( from invokeai.app.services.events.events_common import EventBase
EventBase,
)
class FastAPIEventService(EventServiceBase): class FastAPIEventService(EventServiceBase):
def __init__(self, event_handler_id: int) -> None: def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
self.event_handler_id = event_handler_id self.event_handler_id = event_handler_id
self._queue = Queue[EventBase | None]() self._queue = asyncio.Queue[EventBase | None]()
self._stop_event = threading.Event() self._stop_event = threading.Event()
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event)) self._loop = loop
self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
super().__init__() super().__init__()
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):
self._stop_event.set() self._stop_event.set()
self._queue.put(None) self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
def dispatch(self, event: EventBase) -> None: def dispatch(self, event: EventBase) -> None:
self._queue.put(event) self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
async def _dispatch_from_queue(self, stop_event: threading.Event): async def _dispatch_from_queue(self, stop_event: threading.Event):
"""Get events on from the queue and dispatch them, from the correct thread""" """Get events on from the queue and dispatch them, from the correct thread"""
while not stop_event.is_set(): while not stop_event.is_set():
try: try:
event = self._queue.get(block=False) event = await self._queue.get()
if not event: # Probably stopping if not event: # Probably stopping
continue continue
# Leave the payloads as live pydantic models # Leave the payloads as live pydantic models
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False) dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
except Empty:
await asyncio.sleep(0.1)
pass
except asyncio.CancelledError as e: except asyncio.CancelledError as e:
raise e # Raise a proper error raise e # Raise a proper error