Rename graph processor to session runner to better describe what it's doing, add before/after callbacks for sessions

This commit is contained in:
Brandon Rising 2024-03-04 15:44:29 -05:00 committed by Brandon
parent 7d5a88b69d
commit 46c904d08a

View File

@ -21,8 +21,8 @@ from .session_processor_base import SessionProcessorBase
from .session_processor_common import SessionProcessorStatus from .session_processor_common import SessionProcessorStatus
class GraphProcessor: class SessionRunner:
"""Process a graph of invocations""" """Processes a single session's invocations"""
def __init__( def __init__(
self, self,
@ -167,6 +167,8 @@ class GraphProcessor:
class DefaultSessionProcessor(SessionProcessorBase): class DefaultSessionProcessor(SessionProcessorBase):
"""Processes sessions from the session queue"""
def start( def start(
self, self,
invoker: Invoker, invoker: Invoker,
@ -174,10 +176,14 @@ class DefaultSessionProcessor(SessionProcessorBase):
polling_interval: int = 1, polling_interval: int = 1,
on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, on_before_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None, on_after_run_node: Union[Callable[[BaseInvocation, SessionQueueItem], bool], None] = None,
on_before_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
on_after_run_session: Union[Callable[[SessionQueueItem], bool], None] = None,
) -> None: ) -> None:
self._invoker: Invoker = invoker self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None self._queue_item: Optional[SessionQueueItem] = None
self._invocation: Optional[BaseInvocation] = None self._invocation: Optional[BaseInvocation] = None
self.on_before_run_session = on_before_run_session
self.on_after_run_session = on_after_run_session
self._resume_event = ThreadEvent() self._resume_event = ThreadEvent()
self._stop_event = ThreadEvent() self._stop_event = ThreadEvent()
@ -202,7 +208,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
else None else None
) )
self.graph_processor = GraphProcessor( self.session_runner = SessionRunner(
services=self._invoker.services, services=self._invoker.services,
cancel_event=self._cancel_event, cancel_event=self._cancel_event,
profiler=self._profiler, profiler=self._profiler,
@ -275,11 +281,19 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Get the next session to process # Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue() self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None and resume_event.is_set(): if self._queue_item is not None and resume_event.is_set():
# If we have a on_before_run_session callback, call it
if self.on_before_run_session is not None:
self.on_before_run_session(self._queue_item)
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
cancel_event.clear() cancel_event.clear()
# Run the graph # Run the graph
self.graph_processor.run(queue_item=self._queue_item) self.session_runner.run(queue_item=self._queue_item)
# If we have a on_after_run_session callback, call it
if self.on_after_run_session is not None:
self.on_after_run_session(self._queue_item)
# The session is complete, immediately poll for next session # The session is complete, immediately poll for next session
self._queue_item = None self._queue_item = None