mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
084cf26ed6
There's no longer any need for session-scoped events now that we have the session queue. Session started/completed/canceled map 1-to-1 to queue item status events, but queue item status events also have an event for the failed state. We can simplify queue and processor handling substantially by removing session events and using queue item events instead.
- Remove the session-scoped events entirely.
- Remove all event handling from the session queue. The processor still needs to respond to some events from the queue: `QueueClearedEvent`, `BatchEnqueuedEvent` and `QueueItemStatusChangedEvent`.
- Pass an `is_canceled` callback to the invocation context instead of the cancel event.
- Update processor logic to ensure the local instance of the current queue item is synced with the instance in the database. This prevents race conditions and ensures lifecycle callbacks do not receive stale queue item data.
- Update docstrings and comments.
- Add a `complete_queue_item` method to the session queue service as an explicit way to mark a queue item as successfully completed. Previously, the queue listened for session complete events to do this.

Closes #6442
119 lines
4.0 KiB
Python
119 lines
4.0 KiB
Python
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI
|
|
from pydantic import BaseModel
|
|
from socketio import ASGIApp, AsyncServer
|
|
|
|
from invokeai.app.services.events.events_common import (
|
|
BatchEnqueuedEvent,
|
|
BulkDownloadCompleteEvent,
|
|
BulkDownloadErrorEvent,
|
|
BulkDownloadEventBase,
|
|
BulkDownloadStartedEvent,
|
|
DownloadCancelledEvent,
|
|
DownloadCompleteEvent,
|
|
DownloadErrorEvent,
|
|
DownloadProgressEvent,
|
|
DownloadStartedEvent,
|
|
FastAPIEvent,
|
|
InvocationCompleteEvent,
|
|
InvocationDenoiseProgressEvent,
|
|
InvocationErrorEvent,
|
|
InvocationStartedEvent,
|
|
ModelEventBase,
|
|
ModelInstallCancelledEvent,
|
|
ModelInstallCompleteEvent,
|
|
ModelInstallDownloadProgressEvent,
|
|
ModelInstallDownloadsCompleteEvent,
|
|
ModelInstallErrorEvent,
|
|
ModelInstallStartedEvent,
|
|
ModelLoadCompleteEvent,
|
|
ModelLoadStartedEvent,
|
|
QueueClearedEvent,
|
|
QueueEventBase,
|
|
QueueItemStatusChangedEvent,
|
|
register_events,
|
|
)
|
|
|
|
|
|
class QueueSubscriptionEvent(BaseModel):
    """Payload a client sends to join or leave a queue's event room."""

    # Identifier of the queue; used as the Socket.IO room name.
    queue_id: str
|
|
|
|
|
|
class BulkDownloadSubscriptionEvent(BaseModel):
    """Payload a client sends to join or leave a bulk download's event room."""

    # Identifier of the bulk download; used as the Socket.IO room name.
    bulk_download_id: str
|
|
|
|
|
|
# Event types emitted only to clients subscribed to the event's queue room
# (room name = event.queue_id).
QUEUE_EVENTS = {
    # Per-invocation lifecycle.
    InvocationStartedEvent,
    InvocationDenoiseProgressEvent,
    InvocationCompleteEvent,
    InvocationErrorEvent,
    # Queue-level state changes.
    QueueItemStatusChangedEvent,
    BatchEnqueuedEvent,
    QueueClearedEvent,
}

# Event types broadcast to every connected client (no room).
MODEL_EVENTS = {
    # Generic download lifecycle.
    DownloadCancelledEvent,
    DownloadCompleteEvent,
    DownloadErrorEvent,
    DownloadProgressEvent,
    DownloadStartedEvent,
    # Model loading.
    ModelLoadStartedEvent,
    ModelLoadCompleteEvent,
    # Model installation.
    ModelInstallDownloadProgressEvent,
    ModelInstallDownloadsCompleteEvent,
    ModelInstallStartedEvent,
    ModelInstallCompleteEvent,
    ModelInstallCancelledEvent,
    ModelInstallErrorEvent,
}

# Event types emitted only to clients subscribed to the bulk download's room
# (room name = event.bulk_download_id).
BULK_DOWNLOAD_EVENTS = {
    BulkDownloadStartedEvent,
    BulkDownloadCompleteEvent,
    BulkDownloadErrorEvent,
}
|
|
|
|
|
|
class SocketIO:
    """Forward server-side events to Socket.IO clients.

    A Socket.IO ASGI app is mounted on the FastAPI app at ``/ws``. Events are
    delivered as follows:

    - Queue events go to the room named by the event's ``queue_id``.
    - Bulk-download events go to the room named by the event's
      ``bulk_download_id``.
    - Model events are broadcast to every connected client.

    Clients join and leave those rooms by sending the subscribe/unsubscribe
    messages named by the class attributes below.
    """

    # Client -> server message names for joining/leaving a queue's room.
    _sub_queue = "subscribe_queue"
    _unsub_queue = "unsubscribe_queue"

    # Client -> server message names for joining/leaving a bulk download's room.
    _sub_bulk_download = "subscribe_bulk_download"
    _unsub_bulk_download = "unsubscribe_bulk_download"

    def __init__(self, app: FastAPI):
        """Create the Socket.IO server, mount it on ``app``, and wire up handlers.

        Args:
            app: FastAPI application on which the Socket.IO ASGI app is mounted
                at ``/ws``.
        """
        # NOTE(review): cors_allowed_origins="*" accepts connections from any origin.
        self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
        app.mount("/ws", self._app)

        # Client-initiated room membership messages.
        self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
        self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
        self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
        self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)

        # Server-side events forwarded to clients. register_events presumably
        # arranges for the handler to be invoked for each event type in the set.
        register_events(QUEUE_EVENTS, self._handle_queue_event)
        register_events(MODEL_EVENTS, self._handle_model_event)
        register_events(BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event)

    async def _handle_sub_queue(self, sid: str, data: Any) -> None:
        """Add client ``sid`` to the room of the queue named in ``data``.

        ``data`` is the raw client payload. It is validated as a
        ``QueueSubscriptionEvent``, so a malformed payload (including a
        non-mapping one) raises pydantic's ``ValidationError``.
        """
        await self._sio.enter_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
        """Remove client ``sid`` from the room of the queue named in ``data``.

        ``data`` is validated as a ``QueueSubscriptionEvent``.
        """
        await self._sio.leave_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
        """Add client ``sid`` to the room of the bulk download named in ``data``.

        ``data`` is validated as a ``BulkDownloadSubscriptionEvent``.
        """
        await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
        """Remove client ``sid`` from the room of the bulk download named in ``data``.

        ``data`` is validated as a ``BulkDownloadSubscriptionEvent``.
        """
        await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
        """Emit a queue event to clients subscribed to its queue.

        ``event[0]`` is used as the Socket.IO event name. ``event[1]`` is the
        payload, which is serialized to JSON-compatible data and sent to the
        room named by its ``queue_id``.
        """
        await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].queue_id)

    async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
        """Broadcast a model event to all connected clients (no room filter)."""
        await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"))

    async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
        """Emit a bulk-download event to clients subscribed to that bulk download.

        The payload is sent to the room named by its ``bulk_download_id``.
        """
        await self._sio.emit(event=event[0], data=event[1].model_dump(mode="json"), room=event[1].bulk_download_id)
|