diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index 20922b64d3..f213f11c50 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -21,8 +21,8 @@ from .session_processor_base import SessionProcessorBase from .session_processor_common import SessionProcessorStatus -class GraphProcessor: - """Process a graph of invocations""" +class SessionRunner: + """Processes a single session's invocations""" def __init__( self, @@ -167,6 +167,8 @@ class GraphProcessor: class DefaultSessionProcessor(SessionProcessorBase): + """Processes sessions from the session queue""" + def start( self, invoker: Invoker, @@ -174,10 +176,14 @@ class DefaultSessionProcessor(SessionProcessorBase): polling_interval: int = 1, on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, + on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, + on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None, ) -> None: self._invoker: Invoker = invoker self._queue_item: Optional[SessionQueueItem] = None self._invocation: Optional[BaseInvocation] = None + self.on_before_run_session = on_before_run_session + self.on_after_run_session = on_after_run_session self._resume_event = ThreadEvent() self._stop_event = ThreadEvent() @@ -202,7 +208,7 @@ class DefaultSessionProcessor(SessionProcessorBase): else None ) - self.graph_processor = GraphProcessor( + self.session_runner = SessionRunner( services=self._invoker.services, cancel_event=self._cancel_event, profiler=self._profiler, @@ -275,11 +281,19 @@ class DefaultSessionProcessor(SessionProcessorBase): # Get the next session to process self._queue_item = self._invoker.services.session_queue.dequeue() if self._queue_item is not None and resume_event.is_set(): + # If we have a on_before_run_session callback, call it + if self.on_before_run_session is not None: + self.on_before_run_session(self._queue_item) + self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") cancel_event.clear() # Run the graph - self.graph_processor.run(queue_item=self._queue_item) + self.session_runner.run(queue_item=self._queue_item) + + # If we have a on_after_run_session callback, call it + if self.on_after_run_session is not None: + self.on_after_run_session(self._queue_item) # The session is complete, immediately poll for next session self._queue_item = None