2022-12-01 05:33:20 +00:00
|
|
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
|
|
|
|
2024-03-14 08:04:19 +00:00
|
|
|
from typing import Any
|
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
from fastapi import FastAPI
|
2024-03-14 08:04:19 +00:00
|
|
|
from pydantic import BaseModel
|
2023-09-20 22:30:01 +00:00
|
|
|
from socketio import ASGIApp, AsyncServer
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2024-03-14 08:04:19 +00:00
|
|
|
from invokeai.app.services.events.events_common import (
|
|
|
|
BatchEnqueuedEvent,
|
|
|
|
BulkDownloadCompleteEvent,
|
|
|
|
BulkDownloadErrorEvent,
|
2024-03-10 12:23:11 +00:00
|
|
|
BulkDownloadEventBase,
|
2024-03-14 08:04:19 +00:00
|
|
|
BulkDownloadStartedEvent,
|
2024-03-31 01:03:49 +00:00
|
|
|
DownloadCancelledEvent,
|
|
|
|
DownloadCompleteEvent,
|
|
|
|
DownloadErrorEvent,
|
|
|
|
DownloadProgressEvent,
|
|
|
|
DownloadStartedEvent,
|
2024-03-14 08:04:19 +00:00
|
|
|
FastAPIEvent,
|
|
|
|
InvocationCompleteEvent,
|
|
|
|
InvocationDenoiseProgressEvent,
|
|
|
|
InvocationErrorEvent,
|
|
|
|
InvocationStartedEvent,
|
2024-03-10 12:23:11 +00:00
|
|
|
ModelEventBase,
|
2024-03-14 08:04:19 +00:00
|
|
|
ModelInstallCancelledEvent,
|
|
|
|
ModelInstallCompleteEvent,
|
|
|
|
ModelInstallDownloadProgressEvent,
|
2024-03-31 01:03:49 +00:00
|
|
|
ModelInstallDownloadsCompleteEvent,
|
2024-03-14 08:04:19 +00:00
|
|
|
ModelInstallErrorEvent,
|
|
|
|
ModelInstallStartedEvent,
|
|
|
|
ModelLoadCompleteEvent,
|
|
|
|
ModelLoadStartedEvent,
|
|
|
|
QueueClearedEvent,
|
2024-03-10 12:23:11 +00:00
|
|
|
QueueEventBase,
|
2024-03-14 08:04:19 +00:00
|
|
|
QueueItemStatusChangedEvent,
|
|
|
|
SessionCanceledEvent,
|
|
|
|
SessionCompleteEvent,
|
|
|
|
SessionStartedEvent,
|
|
|
|
register_events,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class QueueSubscriptionEvent(BaseModel):
    """Client payload for subscribing to, or unsubscribing from, a queue."""

    queue_id: str  # queue identifier; also used as the Socket.IO room name
|
|
|
|
|
|
|
|
|
|
|
|
class BulkDownloadSubscriptionEvent(BaseModel):
    """Client payload for subscribing to, or unsubscribing from, a bulk download."""

    bulk_download_id: str  # bulk download identifier; also used as the Socket.IO room name
|
2022-12-01 05:33:20 +00:00
|
|
|
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2022-12-01 05:33:20 +00:00
|
|
|
class SocketIO:
    """Bridge between the app's event bus and Socket.IO clients.

    Mounts a Socket.IO ASGI app under ``/ws`` on the given FastAPI app.
    Clients can join and leave per-queue and per-bulk-download rooms, and
    registered application events are forwarded to connected clients.
    """

    # Client -> server event names used to manage room membership.
    _sub_queue = "subscribe_queue"
    _unsub_queue = "unsubscribe_queue"

    _sub_bulk_download = "subscribe_bulk_download"
    _unsub_bulk_download = "unsubscribe_bulk_download"

    def __init__(self, app: FastAPI):
        """Create the Socket.IO server, mount it on *app*, and wire up all handlers.

        Args:
            app: The FastAPI application the Socket.IO ASGI app is mounted on.
        """
        self._sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
        self._app = ASGIApp(socketio_server=self._sio, socketio_path="/ws/socket.io")
        app.mount("/ws", self._app)

        # Client-initiated room subscription management.
        self._sio.on(self._sub_queue, handler=self._handle_sub_queue)
        self._sio.on(self._unsub_queue, handler=self._handle_unsub_queue)
        self._sio.on(self._sub_bulk_download, handler=self._handle_sub_bulk_download)
        self._sio.on(self._unsub_bulk_download, handler=self._handle_unsub_bulk_download)

        # Queue-scoped events: delivered only to clients in the event's queue room.
        register_events(
            {
                InvocationStartedEvent,
                InvocationDenoiseProgressEvent,
                InvocationCompleteEvent,
                InvocationErrorEvent,
                SessionStartedEvent,
                SessionCompleteEvent,
                SessionCanceledEvent,
                QueueItemStatusChangedEvent,
                BatchEnqueuedEvent,
                QueueClearedEvent,
            },
            self._handle_queue_event,
        )

        # Download and model events: emitted without a room, i.e. to all clients.
        register_events(
            {
                DownloadCancelledEvent,
                DownloadCompleteEvent,
                DownloadErrorEvent,
                DownloadProgressEvent,
                DownloadStartedEvent,
                ModelLoadStartedEvent,
                ModelLoadCompleteEvent,
                ModelInstallDownloadProgressEvent,
                ModelInstallDownloadsCompleteEvent,
                ModelInstallStartedEvent,
                ModelInstallCompleteEvent,
                ModelInstallCancelledEvent,
                ModelInstallErrorEvent,
            },
            self._handle_model_event,
        )

        # Bulk download events: delivered only to clients in that bulk download's room.
        register_events(
            {BulkDownloadStartedEvent, BulkDownloadCompleteEvent, BulkDownloadErrorEvent},
            self._handle_bulk_image_download_event,
        )

    async def _handle_sub_queue(self, sid: str, data: Any) -> None:
        """Add client *sid* to the room named after the requested queue id.

        Args:
            sid: Socket.IO session id of the requesting client.
            data: Client-supplied payload; must validate as a ``QueueSubscriptionEvent``.

        Raises:
            pydantic.ValidationError: If *data* is not a valid subscription payload.
        """
        # model_validate (rather than Model(**data)) also rejects non-mapping
        # client input with a ValidationError instead of a TypeError.
        await self._sio.enter_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_unsub_queue(self, sid: str, data: Any) -> None:
        """Remove client *sid* from the room named after the given queue id.

        Args:
            sid: Socket.IO session id of the requesting client.
            data: Client-supplied payload; must validate as a ``QueueSubscriptionEvent``.

        Raises:
            pydantic.ValidationError: If *data* is not a valid subscription payload.
        """
        await self._sio.leave_room(sid, QueueSubscriptionEvent.model_validate(data).queue_id)

    async def _handle_sub_bulk_download(self, sid: str, data: Any) -> None:
        """Add client *sid* to the room named after the requested bulk download id.

        Args:
            sid: Socket.IO session id of the requesting client.
            data: Client-supplied payload; must validate as a ``BulkDownloadSubscriptionEvent``.

        Raises:
            pydantic.ValidationError: If *data* is not a valid subscription payload.
        """
        await self._sio.enter_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_unsub_bulk_download(self, sid: str, data: Any) -> None:
        """Remove client *sid* from the room named after the given bulk download id.

        Args:
            sid: Socket.IO session id of the requesting client.
            data: Client-supplied payload; must validate as a ``BulkDownloadSubscriptionEvent``.

        Raises:
            pydantic.ValidationError: If *data* is not a valid subscription payload.
        """
        await self._sio.leave_room(sid, BulkDownloadSubscriptionEvent.model_validate(data).bulk_download_id)

    async def _handle_queue_event(self, event: FastAPIEvent[QueueEventBase]) -> None:
        """Emit a queue event to the clients subscribed to its queue.

        Args:
            event: An ``(event_name, payload)`` pair; ``payload.queue_id`` selects the room.
        """
        event_name, payload = event
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.queue_id)

    async def _handle_model_event(self, event: FastAPIEvent[ModelEventBase]) -> None:
        """Emit a download/model event with no room restriction.

        Args:
            event: An ``(event_name, payload)`` pair.
        """
        event_name, payload = event
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"))

    async def _handle_bulk_image_download_event(self, event: FastAPIEvent[BulkDownloadEventBase]) -> None:
        """Emit a bulk download event to the clients subscribed to that bulk download.

        Args:
            event: An ``(event_name, payload)`` pair; ``payload.bulk_download_id`` selects the room.
        """
        event_name, payload = event
        await self._sio.emit(event=event_name, data=payload.model_dump(mode="json"), room=payload.bulk_download_id)
|