mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(app): use asyncio queue and existing event loop for events
Around the time we (I) implemented pydantic events, I noticed a short pause between progress images every 4 or 5 steps when generating with SDXL. It didn't happen with SD1.5, but I did notice that with SD1.5, we'd get 4 or 5 progress events simultaneously. I'd expect one event every ~25ms, matching my it/s with SD1.5. Mysterious! Digging in, I found an issue is related to our use of a synchronous queue for events. When the event queue is empty, we must call `asyncio.sleep` before checking again. We were sleeping for 100ms. Said another way, every time we clear the event queue, we have to wait 100ms before another event can be dispatched, even if it is put on the queue immediately after we start waiting. In practice, this means our events get buffered into batches, dispatched once every 100ms. This explains why I was getting batches of 4 or 5 SD1.5 progress events at once, but not the intermittent SDXL delay. But this 100ms wait has another effect when the events are put on the queue in intervals that don't perfectly line up with the 100ms wait. This is most noticeable when the time between events is >100ms, and can add up to 100ms delay before the event is dispatched. For example, say the queue is empty and we start a 100ms wait. Then, immediately after - like 0.01ms later - we push an event on to the queue. We still need to wait another 99.9ms before that event will be dispatched. That's the SDXL delay. The easy fix is to reduce the sleep to something like 0.01 seconds, but this feels kinda dirty. Can't we just wait on the queue and dispatch every event immediately? Not with the normal synchronous queue - but we can with `asyncio.Queue`. I switched the events queue to use `asyncio.Queue` (as seen in this commit), which lets us asynchronous wait on the queue in a loop. Unfortunately, I ran into another issue - events now felt like their timing was inconsistent, but in a different way than with the 100ms sleep. The time between pushing events on the queue and dispatching them was not consistently ~0ms as I'd expect - it was highly variable from ~0ms up to ~100ms. This is resolved by passing the asyncio loop directly into the events service and using its methods to create the task and interact with the queue. I don't fully understand why this resolved the issue, because either way we are interacting with the same event loop (as shown by `asyncio.get_running_loop()`). I suppose there's some scheduling magic happening.
This commit is contained in:
parent
8ecf72838d
commit
29325a7214
@ -1,5 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -63,7 +64,12 @@ class ApiDependencies:
|
|||||||
invoker: Invoker
|
invoker: Invoker
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger = logger) -> None:
|
def initialize(
|
||||||
|
config: InvokeAIAppConfig,
|
||||||
|
event_handler_id: int,
|
||||||
|
loop: asyncio.AbstractEventLoop,
|
||||||
|
logger: Logger = logger,
|
||||||
|
) -> None:
|
||||||
logger.info(f"InvokeAI version {__version__}")
|
logger.info(f"InvokeAI version {__version__}")
|
||||||
logger.info(f"Root directory = {str(config.root_path)}")
|
logger.info(f"Root directory = {str(config.root_path)}")
|
||||||
|
|
||||||
@ -84,7 +90,7 @@ class ApiDependencies:
|
|||||||
board_images = BoardImagesService()
|
board_images = BoardImagesService()
|
||||||
board_records = SqliteBoardRecordStorage(db=db)
|
board_records = SqliteBoardRecordStorage(db=db)
|
||||||
boards = BoardService()
|
boards = BoardService()
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||||
bulk_download = BulkDownloadService()
|
bulk_download = BulkDownloadService()
|
||||||
image_records = SqliteImageRecordStorage(db=db)
|
image_records = SqliteImageRecordStorage(db=db)
|
||||||
images = ImageService()
|
images = ImageService()
|
||||||
|
@ -55,11 +55,13 @@ mimetypes.add_type("text/css", ".css")
|
|||||||
torch_device_name = TorchDevice.get_torch_device_name()
|
torch_device_name = TorchDevice.get_torch_device_name()
|
||||||
logger.info(f"Using torch device: {torch_device_name}")
|
logger.info(f"Using torch device: {torch_device_name}")
|
||||||
|
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Add startup event to load dependencies
|
# Add startup event to load dependencies
|
||||||
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, logger=logger)
|
ApiDependencies.initialize(config=app_config, event_handler_id=event_handler_id, loop=loop, logger=logger)
|
||||||
yield
|
yield
|
||||||
# Shut down threads
|
# Shut down threads
|
||||||
ApiDependencies.shutdown()
|
ApiDependencies.shutdown()
|
||||||
@ -184,8 +186,6 @@ def invoke_api() -> None:
|
|||||||
|
|
||||||
check_cudnn(logger)
|
check_cudnn(logger)
|
||||||
|
|
||||||
# Start our own event loop for eventing usage
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
config = uvicorn.Config(
|
config = uvicorn.Config(
|
||||||
app=app,
|
app=app,
|
||||||
host=app_config.host,
|
host=app_config.host,
|
||||||
|
@ -1,46 +1,38 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
from queue import Empty, Queue
|
|
||||||
|
|
||||||
from fastapi_events.dispatcher import dispatch
|
from fastapi_events.dispatcher import dispatch
|
||||||
|
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.events.events_common import (
|
from invokeai.app.services.events.events_common import EventBase
|
||||||
EventBase,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FastAPIEventService(EventServiceBase):
|
class FastAPIEventService(EventServiceBase):
|
||||||
def __init__(self, event_handler_id: int) -> None:
|
def __init__(self, event_handler_id: int, loop: asyncio.AbstractEventLoop) -> None:
|
||||||
self.event_handler_id = event_handler_id
|
self.event_handler_id = event_handler_id
|
||||||
self._queue = Queue[EventBase | None]()
|
self._queue = asyncio.Queue[EventBase | None]()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
asyncio.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
self._loop = loop
|
||||||
|
self._loop.create_task(self._dispatch_from_queue(stop_event=self._stop_event))
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def stop(self, *args, **kwargs):
|
def stop(self, *args, **kwargs):
|
||||||
self._stop_event.set()
|
self._stop_event.set()
|
||||||
self._queue.put(None)
|
self._loop.call_soon_threadsafe(self._queue.put_nowait, None)
|
||||||
|
|
||||||
def dispatch(self, event: EventBase) -> None:
|
def dispatch(self, event: EventBase) -> None:
|
||||||
self._queue.put(event)
|
self._loop.call_soon_threadsafe(self._queue.put_nowait, event)
|
||||||
|
|
||||||
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
async def _dispatch_from_queue(self, stop_event: threading.Event):
|
||||||
"""Get events on from the queue and dispatch them, from the correct thread"""
|
"""Get events on from the queue and dispatch them, from the correct thread"""
|
||||||
while not stop_event.is_set():
|
while not stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
event = self._queue.get(block=False)
|
event = await self._queue.get()
|
||||||
if not event: # Probably stopping
|
if not event: # Probably stopping
|
||||||
continue
|
continue
|
||||||
# Leave the payloads as live pydantic models
|
# Leave the payloads as live pydantic models
|
||||||
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
dispatch(event, middleware_id=self.event_handler_id, payload_schema_dump=False)
|
||||||
|
|
||||||
except Empty:
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
pass
|
|
||||||
|
|
||||||
except asyncio.CancelledError as e:
|
except asyncio.CancelledError as e:
|
||||||
raise e # Raise a proper error
|
raise e # Raise a proper error
|
||||||
|
Loading…
Reference in New Issue
Block a user