From b777dba430d8611a4f07c3dcd4ea91f93ea0be1f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 5 Sep 2023 21:17:33 +1000 Subject: [PATCH] feat: batch events When a batch creates a session, we need to alert the client of this. Because the sessions are created by the batch manager (not directly in response to a client action), we need to emit an event with the session id. To accomodate this, a secondary set of sio sub/unsub/event handlers are created. These are specifically for batch events. The room is the `batch_id`. When creating a batch, the client subscribes to this batch room. When the batch manager creates a batch session, a `batch_session_created` event is emitted in the appropriate room. It includes the session id. The client then may subscribe to the session room, and all socket stuff proceeds as it did before. --- invokeai/app/api/sockets.py | 29 ++++++++++++++++++++------ invokeai/app/services/batch_manager.py | 1 + invokeai/app/services/events.py | 21 +++++++++++++++++++ 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/invokeai/app/api/sockets.py b/invokeai/app/api/sockets.py index 4591bac540..1c4a59d0a4 100644 --- a/invokeai/app/api/sockets.py +++ b/invokeai/app/api/sockets.py @@ -13,11 +13,15 @@ class SocketIO: def __init__(self, app: FastAPI): self.__sio = SocketManager(app=app) - self.__sio.on("subscribe", handler=self._handle_sub) - self.__sio.on("unsubscribe", handler=self._handle_unsub) + self.__sio.on("subscribe_session", handler=self._handle_sub_session) + self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session) local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event) + self.__sio.on("subscribe_batch", handler=self._handle_sub_batch) + self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch) + local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event) + async def _handle_session_event(self, event: Event): await self.__sio.emit( event=event[1]["event"], @@ -25,12 +29,25 @@ class SocketIO: room=event[1]["data"]["graph_execution_state_id"], ) - async def _handle_sub(self, sid, data, *args, **kwargs): + async def _handle_sub_session(self, sid, data, *args, **kwargs): if "session" in data: self.__sio.enter_room(sid, data["session"]) - # @app.sio.on('unsubscribe') - - async def _handle_unsub(self, sid, data, *args, **kwargs): + async def _handle_unsub_session(self, sid, data, *args, **kwargs): if "session" in data: self.__sio.leave_room(sid, data["session"]) + + async def _handle_batch_event(self, event: Event): + await self.__sio.emit( + event=event[1]["event"], + data=event[1]["data"], + room=event[1]["data"]["batch_id"], + ) + + async def _handle_sub_batch(self, sid, data, *args, **kwargs): + if "batch_id" in data: + self.__sio.enter_room(sid, data["batch_id"]) + + async def _handle_unsub_batch(self, sid, data, *args, **kwargs): + if "batch_id" in data: + self.__sio.enter_room(sid, data["batch_id"]) diff --git a/invokeai/app/services/batch_manager.py b/invokeai/app/services/batch_manager.py index fce6760815..d6432d5fc7 100644 --- a/invokeai/app/services/batch_manager.py +++ b/invokeai/app/services/batch_manager.py @@ -160,6 +160,7 @@ class BatchManager(BatchManagerBase): session_id=next_session.session_id, changes=BatchSessionChanges(state="in_progress"), ) + self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id) self.__invoker.invoke(ges, invoke_all=True) def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse: diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index a266fe4f18..0ac7587f19 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import ( class EventServiceBase: session_event: str = "session_event" + batch_event: str = "batch_event" """Basic event bus, to have an empty stand-in when not needed""" @@ -20,12 +21,21 @@ class EventServiceBase: pass def __emit_session_event(self, event_name: str, payload: dict) -> None: + """Session events are emitted to a room with the session_id as the room name""" payload["timestamp"] = get_timestamp() self.dispatch( event_name=EventServiceBase.session_event, payload=dict(event=event_name, data=payload), ) + def __emit_batch_event(self, event_name: str, payload: dict) -> None: + """Batch events are emitted to a room with the batch_id as the room name""" + payload["timestamp"] = get_timestamp() + self.dispatch( + event_name=EventServiceBase.batch_event, + payload=dict(event=event_name, data=payload), + ) + # Define events here for every event in the system. # This will make them easier to integrate until we find a schema generator. def emit_generator_progress( @@ -187,3 +197,14 @@ class EventServiceBase: error=error, ), ) + + def emit_batch_session_created( + self, + batch_id: str, + graph_execution_state_id: str, + ) -> None: + """Emitted when a batch session is created""" + self.__emit_batch_event( + event_name="batch_session_created", + payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id), + )