InvokeAI/invokeai/app/api/sockets.py
psychedelicious 9bd78823a3 refactor(events): use pydantic schemas for events
Our event handling and implementation have a couple of pain points:
- Adding or removing data from event payloads requires changes wherever the events are dispatched from.
- We have no type safety for events and need to rely on string matching and dict access when interacting with events.
- Frontend types for socket events must be manually typed. This has caused several bugs.

`fastapi-events` has a neat feature where you can create a pydantic model as an event payload, give it an `__event_name__` attr, and then dispatch the model directly.

This allows us to eliminate a layer of indirection and some unpleasant complexity:
- Event handler callbacks get type hints for their event payloads, and can use `isinstance` on them if needed.
- Event payload construction is now the responsibility of the event itself (a pydantic model), not the service. Every event model has a `build` class method, encapsulating this logic. The build methods receive as few arguments as possible. For example, `InvocationStartedEvent.build()` gets the invocation instance and queue item, and can choose the data it wants to include in the event payload.
- Frontend event types may be autogenerated from the OpenAPI schema. We use the payload registry feature of `fastapi-events` to collect all payload models into one place, making it trivial to keep our schema and frontend types in sync.

This commit moves the backend over to this improved event handling setup.
2024-05-27 09:06:02 +10:00

120 lines
4.1 KiB
Python

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any
from fastapi import FastAPI
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEvent,
BulkDownloadStartedEvent,
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEvent,
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEvent,
QueueItemStatusChangedEvent,
SessionCanceledEvent,
SessionCompleteEvent,
SessionStartedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
    """Payload of a client's ``subscribe_queue`` / ``unsubscribe_queue`` socket.io message."""

    # Name of the socket.io room to join or leave; queue events are emitted to the room matching their queue_id.
    queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
    """Payload of a client's ``subscribe_bulk_download`` / ``unsubscribe_bulk_download`` socket.io message."""

    # Name of the socket.io room the client joins or leaves.
    bulk_download_id: str
class SocketIO:
    """Forwards backend events to socket.io clients.

    Mounts a socket.io ASGI app on the FastAPI app under ``/ws`` (socket.io path ``/ws/socket.io``) and
    registers ``fastapi-events`` handlers that re-emit event payloads to connected clients.

    Delivery scope:
    - Queue events go only to the room named by the payload's ``queue_id``; clients join/leave that room
      with ``subscribe_queue`` / ``unsubscribe_queue``.
    - Bulk download events go only to the room named by the payload's ``bulk_download_id``; clients
      join/leave that room with ``subscribe_bulk_download`` / ``unsubscribe_bulk_download``.
    - Model events are broadcast to every connected client.
    """

    # Names of the client -> server socket.io messages that manage room membership.
    _sub_queue = "subscribe_queue"
    _unsub_queue = "unsubscribe_queue"
    _sub_bulk_download = "subscribe_bulk_download"
    _unsub_bulk_download = "unsubscribe_bulk_download"

    def __init__(self, app: FastAPI):
        """Mount the socket.io server on ``app`` and register all socket and backend-event handlers.

        Args:
            app: The FastAPI application to mount the socket.io ASGI app on.
        """
        # Accept cross-origin socket.io connections from any origin.
        self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
        app.mount("/ws", self._app)

        # Client -> server: room subscription management.
        self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
        self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
        self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
        self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)

        # Server -> client: forward backend events, grouped by how they are scoped.
        register_events(
            {
                InvocationStartedEvent,
                InvocationDenoiseProgressEvent,
                InvocationCompleteEvent,
                InvocationErrorEvent,
                SessionStartedEvent,
                SessionCompleteEvent,
                SessionCanceledEvent,
                QueueItemStatusChangedEvent,
                BatchEnqueuedEvent,
                QueueClearedEvent,
            },
            self._handle_queue_event,
        )
        register_events(
            {
                ModelLoadStartedEvent,
                ModelLoadCompleteEvent,
                ModelInstallDownloadProgressEvent,
                ModelInstallStartedEvent,
                ModelInstallCompleteEvent,
                ModelInstallCancelledEvent,
                ModelInstallErrorEvent,
            },
            self._handle_model_event,
        )
        register_events(
            {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
            self._handle_bulk_image_download_event,
        )

    async def _handle_sub_queue(self, sid: str, data: Any) -> None:
        """Add client ``sid`` to the room named by the ``queue_id`` in its message ``data``."""
        # ``data`` comes straight from the client; model_validate rejects malformed payloads (including
        # non-mappings) with a pydantic ValidationError rather than a TypeError from ``**`` unpacking.
        await self._sio.enter_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
        """Remove client ``sid`` from the room named by the ``queue_id`` in its message ``data``."""
        await self._sio.leave_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
        """Add client ``sid`` to the room named by the ``bulk_download_id`` in its message ``data``."""
        await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
        """Remove client ``sid`` from the room named by the ``bulk_download_id`` in its message ``data``."""
        await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_queue_event(self, event: FastAPIEvent[QueueEvent]) -> None:
        """Emit a queue event only to clients subscribed to the event's queue."""
        event_name, payload = event
        # mode="json" converts values such as datetimes/UUIDs/paths into JSON-compatible primitives before
        # socket.io serializes the message; it is a no-op for payloads that are already primitive.
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)

    async def _handle_model_event(self, event: FastAPIEvent[ModelEvent]) -> None:
        """Broadcast a model load/install event to every connected client."""
        event_name, payload = event
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))

    async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEvent]) -> None:
        """Emit a bulk download event only to clients subscribed to that bulk download."""
        event_name, payload = event
        # Scope to the subscribed room: broadcasting would make the (un)subscribe_bulk_download handlers
        # meaningless and deliver one client's download notifications to every connected client.
        # NOTE(review): assumes bulk download payloads carry ``bulk_download_id``, the same key clients
        # subscribe with (see BulkDownloadSubscriptionEvent) — confirm against events_common.
        await self._sio.emit(
            event=event_name,
            data=payload.model_dump(mode="json"),
            room=payload.bulk_download_id,
        )