adding socket events for bulk download

This commit is contained in:
Stefan Tobler 2024-01-07 19:55:59 -05:00 committed by Brandon Rising
parent aba9cd3f9a
commit cf6eb1394a
2 changed files with 59 additions and 2 deletions

View File

@ -12,16 +12,27 @@ class SocketIO:
__sio: AsyncServer __sio: AsyncServer
__app: ASGIApp __app: ASGIApp
__sub_queue: str = "subscribe_queue"
__unsub_queue: str = "unsubscribe_queue"
__sub_bulk_download: str = "subscribe_bulk_download"
__unsub_bulk_download: str = "unsubscribe_bulk_download"
def __init__(self, app: FastAPI): def __init__(self, app: FastAPI):
self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*")
self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io")
app.mount("/ws", self.__app) app.mount("/ws", self.__app)
self.__sio.on("subscribe_queue", handler=self._handle_sub_queue) self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue)
self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue) self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue)
local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event)
local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event)
self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download)
self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download)
local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event)
async def _handle_queue_event(self, event: Event): async def _handle_queue_event(self, event: Event):
await self.__sio.emit( await self.__sio.emit(
event=event[1]["event"], event=event[1]["event"],
@ -39,3 +50,18 @@ class SocketIO:
async def _handle_model_event(self, event: Event) -> None: async def _handle_model_event(self, event: Event) -> None:
await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) await self.__sio.emit(event=event[1]["event"], data=event[1]["data"])
async def _handle_bulk_download_event(self, event: Event):
await self.__sio.emit(
event=event[1]["event"],
data=event[1]["data"],
room=event[1]["data"]["bulk_download_id"],
)
async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.enter_room(sid, data["bulk_download_id"])
async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs):
if "bulk_download_id" in data:
await self.__sio.leave_room(sid, data["bulk_download_id"])

View File

@ -16,6 +16,7 @@ from invokeai.backend.model_manager import AnyModelConfig
class EventServiceBase: class EventServiceBase:
queue_event: str = "queue_event" queue_event: str = "queue_event"
bulk_download_event: str = "bulk_download_event"
download_event: str = "download_event" download_event: str = "download_event"
model_event: str = "model_event" model_event: str = "model_event"
@ -24,6 +25,14 @@ class EventServiceBase:
def dispatch(self, event_name: str, payload: Any) -> None: def dispatch(self, event_name: str, payload: Any) -> None:
pass pass
def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None:
"""Bulk download events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp()
self.dispatch(
event_name=EventServiceBase.bulk_download_event,
payload={"event": event_name, "data": payload},
)
def __emit_queue_event(self, event_name: str, payload: dict) -> None: def __emit_queue_event(self, event_name: str, payload: dict) -> None:
"""Queue events are emitted to a room with queue_id as the room name""" """Queue events are emitted to a room with queue_id as the room name"""
payload["timestamp"] = get_timestamp() payload["timestamp"] = get_timestamp()
@ -430,3 +439,25 @@ class EventServiceBase:
"error": error, "error": error,
}, },
) )
def emit_bulk_download_started(self, bulk_download_id: str) -> None:
"""Emitted when a bulk download starts"""
self._emit_bulk_download_event(
event_name="bulk_download_started",
payload={"bulk_download_id": bulk_download_id, }
)
def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None:
"""Emitted when a bulk download completes"""
self._emit_bulk_download_event(
event_name="bulk_download_completed",
payload={"bulk_download_id": bulk_download_id,
"file_path": file_path}
)
def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None:
"""Emitted when a bulk download fails"""
self._emit_bulk_download_event(
event_name="bulk_download_failed",
payload={"bulk_download_id": bulk_download_id, "error": error}
)