mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: batch events
When a batch creates a session, we need to alert the client of this. Because the sessions are created by the batch manager (not directly in response to a client action), we need to emit an event with the session id. To accomodate this, a secondary set of sio sub/unsub/event handlers are created. These are specifically for batch events. The room is the `batch_id`. When creating a batch, the client subscribes to this batch room. When the batch manager creates a batch session, a `batch_session_created` event is emitted in the appropriate room. It includes the session id. The client then may subscribe to the session room, and all socket stuff proceeds as it did before.
This commit is contained in:
parent
531c3bb1e2
commit
b777dba430
@ -13,11 +13,15 @@ class SocketIO:
|
||||
|
||||
def __init__(self, app: FastAPI):
|
||||
self.__sio = SocketManager(app=app)
|
||||
self.__sio.on("subscribe", handler=self._handle_sub)
|
||||
self.__sio.on("unsubscribe", handler=self._handle_unsub)
|
||||
|
||||
self.__sio.on("subscribe_session", handler=self._handle_sub_session)
|
||||
self.__sio.on("unsubscribe_session", handler=self._handle_unsub_session)
|
||||
local_handler.register(event_name=EventServiceBase.session_event, _func=self._handle_session_event)
|
||||
|
||||
self.__sio.on("subscribe_batch", handler=self._handle_sub_batch)
|
||||
self.__sio.on("unsubscribe_batch", handler=self._handle_unsub_batch)
|
||||
local_handler.register(event_name=EventServiceBase.batch_event, _func=self._handle_batch_event)
|
||||
|
||||
async def _handle_session_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
@ -25,12 +29,25 @@ class SocketIO:
|
||||
room=event[1]["data"]["graph_execution_state_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub(self, sid, data, *args, **kwargs):
|
||||
async def _handle_sub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.enter_room(sid, data["session"])
|
||||
|
||||
# @app.sio.on('unsubscribe')
|
||||
|
||||
async def _handle_unsub(self, sid, data, *args, **kwargs):
|
||||
async def _handle_unsub_session(self, sid, data, *args, **kwargs):
|
||||
if "session" in data:
|
||||
self.__sio.leave_room(sid, data["session"])
|
||||
|
||||
async def _handle_batch_event(self, event: Event):
|
||||
await self.__sio.emit(
|
||||
event=event[1]["event"],
|
||||
data=event[1]["data"],
|
||||
room=event[1]["data"]["batch_id"],
|
||||
)
|
||||
|
||||
async def _handle_sub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
||||
async def _handle_unsub_batch(self, sid, data, *args, **kwargs):
|
||||
if "batch_id" in data:
|
||||
self.__sio.enter_room(sid, data["batch_id"])
|
||||
|
@ -160,6 +160,7 @@ class BatchManager(BatchManagerBase):
|
||||
session_id=next_session.session_id,
|
||||
changes=BatchSessionChanges(state="in_progress"),
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_session_created(next_session.batch_id, next_session.session_id)
|
||||
self.__invoker.invoke(ges, invoke_all=True)
|
||||
|
||||
def create_batch_process(self, batch: Batch, graph: Graph) -> BatchProcessResponse:
|
||||
|
@ -13,6 +13,7 @@ from invokeai.app.services.model_manager_service import (
|
||||
|
||||
class EventServiceBase:
|
||||
session_event: str = "session_event"
|
||||
batch_event: str = "batch_event"
|
||||
|
||||
"""Basic event bus, to have an empty stand-in when not needed"""
|
||||
|
||||
@ -20,12 +21,21 @@ class EventServiceBase:
|
||||
pass
|
||||
|
||||
def __emit_session_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Session events are emitted to a room with the session_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.session_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
def __emit_batch_event(self, event_name: str, payload: dict) -> None:
|
||||
"""Batch events are emitted to a room with the batch_id as the room name"""
|
||||
payload["timestamp"] = get_timestamp()
|
||||
self.dispatch(
|
||||
event_name=EventServiceBase.batch_event,
|
||||
payload=dict(event=event_name, data=payload),
|
||||
)
|
||||
|
||||
# Define events here for every event in the system.
|
||||
# This will make them easier to integrate until we find a schema generator.
|
||||
def emit_generator_progress(
|
||||
@ -187,3 +197,14 @@ class EventServiceBase:
|
||||
error=error,
|
||||
),
|
||||
)
|
||||
|
||||
def emit_batch_session_created(
|
||||
self,
|
||||
batch_id: str,
|
||||
graph_execution_state_id: str,
|
||||
) -> None:
|
||||
"""Emitted when a batch session is created"""
|
||||
self.__emit_batch_event(
|
||||
event_name="batch_session_created",
|
||||
payload=dict(batch_id=batch_id, graph_execution_state_id=graph_execution_state_id),
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user