InvokeAI/invokeai/app/api/sockets.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

125 lines
4.3 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
2024-03-14 08:04:19 +00:00
from typing import Any
from fastapi import FastAPI
2024-03-14 08:04:19 +00:00
from pydantic import BaseModel
from socketio import ASGIApp, AsyncServer
2023-03-03 06:02:00 +00:00
2024-03-14 08:04:19 +00:00
from invokeai.app.services.events.events_common import (
BatchEnqueuedEvent,
BulkDownloadCompleteEvent,
BulkDownloadErrorEvent,
BulkDownloadEventBase,
2024-03-14 08:04:19 +00:00
BulkDownloadStartedEvent,
DownloadCancelledEvent,
DownloadCompleteEvent,
DownloadErrorEvent,
DownloadProgressEvent,
DownloadStartedEvent,
2024-03-14 08:04:19 +00:00
FastAPIEvent,
InvocationCompleteEvent,
InvocationDenoiseProgressEvent,
InvocationErrorEvent,
InvocationStartedEvent,
ModelEventBase,
2024-03-14 08:04:19 +00:00
ModelInstallCancelledEvent,
ModelInstallCompleteEvent,
ModelInstallDownloadProgressEvent,
ModelInstallDownloadsCompleteEvent,
2024-03-14 08:04:19 +00:00
ModelInstallErrorEvent,
ModelInstallStartedEvent,
ModelLoadCompleteEvent,
ModelLoadStartedEvent,
QueueClearedEvent,
QueueEventBase,
2024-03-14 08:04:19 +00:00
QueueItemStatusChangedEvent,
register_events,
)
class QueueSubscriptionEvent(BaseModel):
    """Payload a client sends to subscribe to (or unsubscribe from) a queue room.

    Modelled with pydantic so the incoming socket.io data is validated
    before its ``queue_id`` is used as a room name.
    """

    # Identifier of the queue; used verbatim as the socket.io room name.
    queue_id: str
class BulkDownloadSubscriptionEvent(BaseModel):
    """Payload a client sends to subscribe to (or unsubscribe from) a bulk download room.

    Modelled with pydantic so the incoming socket.io data is validated
    before its ``bulk_download_id`` is used as a room name.
    """

    # Identifier of the bulk download; used verbatim as the socket.io room name.
    bulk_download_id: str
2023-03-03 06:02:00 +00:00
# Event types forwarded by SocketIO._handle_queue_event, i.e. emitted to the
# socket.io room named after the event's ``queue_id``.
QUEUE_EVENTS = {
    InvocationStartedEvent,
    InvocationDenoiseProgressEvent,
    InvocationCompleteEvent,
    InvocationErrorEvent,
    QueueItemStatusChangedEvent,
    BatchEnqueuedEvent,
    QueueClearedEvent,
}
# Event types forwarded by SocketIO._handle_model_event, which emits them
# without specifying a room.
MODEL_EVENTS = {
    DownloadCancelledEvent,
    DownloadCompleteEvent,
    DownloadErrorEvent,
    DownloadProgressEvent,
    DownloadStartedEvent,
    ModelLoadStartedEvent,
    ModelLoadCompleteEvent,
    ModelInstallDownloadProgressEvent,
    ModelInstallDownloadsCompleteEvent,
    ModelInstallStartedEvent,
    ModelInstallCompleteEvent,
    ModelInstallCancelledEvent,
    ModelInstallErrorEvent,
}
# Event types forwarded by SocketIO._handle_bulk_image_download_event, i.e.
# emitted to the socket.io room named after the event's ``bulk_download_id``.
BULK_DOWNLOAD_EVENTS = {
    BulkDownloadStartedEvent,
    BulkDownloadCompleteEvent,
    BulkDownloadErrorEvent,
}
class SocketIO:
    """Relay application events to socket.io clients.

    On construction this mounts a socket.io ASGI application on the given
    FastAPI app at ``/ws``, wires up the client-initiated subscribe /
    unsubscribe messages, and registers callbacks that forward queue, model
    and bulk-download events to connected clients.
    """

    # Names of the client -> server messages for queue rooms.
    _sub_queue = "subscribe_queue"
    _unsub_queue = "unsubscribe_queue"

    # Names of the client -> server messages for bulk download rooms.
    _sub_bulk_download = "subscribe_bulk_download"
    _unsub_bulk_download = "unsubscribe_bulk_download"

    def __init__(self, app: FastAPI):
        """Create the socket.io server, mount it on *app*, and register all handlers.

        Args:
            app: The FastAPI application the socket.io ASGI app is mounted on.
        """
        server = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self._sio = server
        self._app = ASGIApp(socketio_server=server, socketio_path="/ws/socket.io")
        app.mount("/ws", self._app)

        # Client-initiated messages: joining and leaving rooms.
        client_handlers = (
            (self._sub_queue, self._handle_sub_queue),
            (self._unsub_queue, self._handle_unsub_queue),
            (self._sub_bulk_download, self._handle_sub_bulk_download),
            (self._unsub_bulk_download, self._handle_unsub_bulk_download),
        )
        for message_name, handler in client_handlers:
            server.on(message_name, handler=handler)

        # Application events: each group of event types gets its own forwarder.
        event_routes = (
            (QUEUE_EVENTS, self._handle_queue_event),
            (MODEL_EVENTS, self._handle_model_event),
            (BULK_DOWNLOAD_EVENTS, self._handle_bulk_image_download_event),
        )
        for event_types, forwarder in event_routes:
            register_events(event_types, forwarder)

    async def _handle_sub_queue(self, sid: str, data: Any) -> None:
        """Validate *data* and add client *sid* to the room named by its queue id."""
        subscription = QueueSubscriptionEvent(**data)
        await self._sio.enter_room(sid, subscription.queue_id)

    async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
        """Validate *data* and remove client *sid* from the room named by its queue id."""
        subscription = QueueSubscriptionEvent(**data)
        await self._sio.leave_room(sid, subscription.queue_id)

    async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
        """Validate *data* and add client *sid* to the room named by its bulk download id."""
        subscription = BulkDownloadSubscriptionEvent(**data)
        await self._sio.enter_room(sid, subscription.bulk_download_id)

    async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
        """Validate *data* and remove client *sid* from the room named by its bulk download id."""
        subscription = BulkDownloadSubscriptionEvent(**data)
        await self._sio.leave_room(sid, subscription.bulk_download_id)

    async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
        """Emit a queue event to the room named after the event's ``queue_id``.

        ``event[0]`` is used as the socket.io event name and ``event[1]`` is
        the payload model, serialised with ``model_dump(mode="json")``.
        """
        event_name = event[0]
        payload = event[1]
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)

    async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
        """Emit a model event without specifying a room.

        ``event[0]`` is used as the socket.io event name and ``event[1]`` is
        the payload model, serialised with ``model_dump(mode="json")``.
        """
        event_name = event[0]
        payload = event[1]
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))

    async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
        """Emit a bulk download event to the room named after its ``bulk_download_id``.

        ``event[0]`` is used as the socket.io event name and ``event[1]`` is
        the payload model, serialised with ``model_dump(mode="json")``.
        """
        event_name = event[0]
        payload = event[1]
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)