diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index e651e43559..c5d9ace8d2 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -12,16 +12,27 @@ class SocketIO: __sio: AsyncServer __app: ASGIApp + __sub_queue: str = "subscribe_queue" + __unsub_queue: str = "unsubscribe_queue" + + __sub_bulk_download: str = "subscribe_bulk_download" + __unsub_bulk_download: str = "unsubscribe_bulk_download" + + def __init__(self, app: FastAPI): self.__sio = AsyncServer(async_mode="asgi", cors_allowed_origins="*") self.__app = ASGIApp(socketio_server=self.__sio, socketio_path="/ws/socket.io") app.mount("/ws", self.__app) - self.__sio.on("subscribe_queue", handler=self._handle_sub_queue) - self.__sio.on("unsubscribe_queue", handler=self._handle_unsub_queue) + self.__sio.on(self.__sub_queue, handler=self._handle_sub_queue) + self.__sio.on(self.__unsub_queue, handler=self._handle_unsub_queue) local_handler.register(event_name=EventServiceBase.queue_event, _func=self._handle_queue_event) local_handler.register(event_name=EventServiceBase.model_event, _func=self._handle_model_event) + self.__sio.on(self.__sub_bulk_download, handler=self._handle_sub_bulk_download) + self.__sio.on(self.__unsub_bulk_download, handler=self._handle_unsub_bulk_download) + local_handler.register(event_name=EventServiceBase.bulk_download_event, _func=self._handle_bulk_download_event) + async def _handle_queue_event(self, event: Event): await self.__sio.emit( event=event[1]["event"], @@ -39,3 +50,18 @@ class SocketIO: async def _handle_model_event(self, event: Event) -> None: await self.__sio.emit(event=event[1]["event"], data=event[1]["data"]) + + async def _handle_bulk_download_event(self, event: Event): + await self.__sio.emit( + event=event[1]["event"], + data=event[1]["data"], + room=event[1]["data"]["bulk_download_id"], + ) + + async def _handle_sub_bulk_download(self, sid, data, *args, **kwargs): + if "bulk_download_id" in data: + await self.__sio.enter_room(sid, data["bulk_download_id"]) + + async def _handle_unsub_bulk_download(self, sid, data, *args, **kwargs): + if "bulk_download_id" in data: + await self.__sio.leave_room(sid, data["bulk_download_id"]) diff --git a/invokeai/app/services/events/events_base.py b/invokeai/app/services/events/events_base.py index 5355fe2298..0a0668b274 100644 --- a/invokeai/app/services/events/events_base.py +++ b/invokeai/app/services/events/events_base.py @@ -16,6 +16,7 @@ from invokeai.backend.model_manager import AnyModelConfig class EventServiceBase: queue_event: str = "queue_event" + bulk_download_event: str = "bulk_download_event" download_event: str = "download_event" model_event: str = "model_event" @@ -24,6 +25,14 @@ class EventServiceBase: def dispatch(self, event_name: str, payload: Any) -> None: pass + def _emit_bulk_download_event(self, event_name: str, payload: dict) -> None: + """Bulk download events are emitted to a room with queue_id as the room name""" + payload["timestamp"] = get_timestamp() + self.dispatch( + event_name=EventServiceBase.bulk_download_event, + payload={"event": event_name, "data": payload}, + ) + def __emit_queue_event(self, event_name: str, payload: dict) -> None: """Queue events are emitted to a room with queue_id as the room name""" payload["timestamp"] = get_timestamp() @@ -430,3 +439,25 @@ class EventServiceBase: "error": error, }, ) + + def emit_bulk_download_started(self, bulk_download_id: str) -> None: + """Emitted when a bulk download starts""" + self._emit_bulk_download_event( + event_name="bulk_download_started", + payload={"bulk_download_id": bulk_download_id, } + ) + + def emit_bulk_download_completed(self, bulk_download_id: str, file_path: str) -> None: + """Emitted when a bulk download completes""" + self._emit_bulk_download_event( + event_name="bulk_download_completed", + payload={"bulk_download_id": bulk_download_id, + "file_path": file_path} + ) + + def emit_bulk_download_failed(self, bulk_download_id: str, error: str) -> None: + """Emitted when a bulk download fails""" + self._emit_bulk_download_event( + event_name="bulk_download_failed", + payload={"bulk_download_id": bulk_download_id, "error": error} + )