docs(processor): update docstrings, comments

This commit is contained in:
psychedelicious 2024-05-24 10:20:20 +10:00
parent 6063487b20
commit a98ddedb95
2 changed files with 97 additions and 20 deletions

View File

@ -16,17 +16,33 @@ class SessionRunnerBase(ABC):
@abstractmethod
def start(self, services: InvocationServices, cancel_event: Event, profiler: Optional[Profiler] = None) -> None:
"""Starts the session runner"""
"""Starts the session runner.
Args:
services: The invocation services.
cancel_event: The cancel event.
profiler: The profiler to use for session profiling via cProfile. Omit to disable profiling. Basic session
stats will be still be recorded and logged when profiling is disabled.
"""
pass
@abstractmethod
def run(self, queue_item: SessionQueueItem) -> None:
"""Runs the session"""
"""Runs a session.
Args:
queue_item: The session to run.
"""
pass
@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Runs an already prepared node on the session"""
"""Run a single node in the graph.
Args:
invocation: The invocation to run.
queue_item: The session queue item.
"""
pass
@ -56,13 +72,25 @@ class SessionProcessorBase(ABC):
class OnBeforeRunNode(Protocol):
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> bool: ...
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that will be executed.
queue_item: The session queue item.
"""
...
class OnAfterRunNode(Protocol):
def __call__(
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
) -> bool: ...
def __call__(self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput) -> None:
"""Callback to run before executing a node.
Args:
invocation: The invocation that was executed.
queue_item: The session queue item.
"""
...
class OnNodeError(Protocol):
@ -73,15 +101,37 @@ class OnNodeError(Protocol):
error_type: str,
error_message: str,
error_traceback: str,
) -> bool: ...
) -> None:
"""Callback to run when a node has an error.
Args:
invocation: The invocation that errored.
queue_item: The session queue item.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...
class OnBeforeRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run before executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnAfterRunSession(Protocol):
def __call__(self, queue_item: SessionQueueItem) -> bool: ...
def __call__(self, queue_item: SessionQueueItem) -> None:
"""Callback to run after executing a session.
Args:
queue_item: The session queue item.
"""
...
class OnNonFatalProcessorError(Protocol):
@ -91,4 +141,13 @@ class OnNonFatalProcessorError(Protocol):
error_type: str,
error_message: str,
error_traceback: str,
) -> bool: ...
) -> None:
"""Callback to run when a non-fatal error occurs in the processor.
Args:
queue_item: The session queue item, if one was being executed when the error occurred.
error_type: The type of error, e.g. "ValueError".
error_message: The error message, e.g. "Invalid value".
error_traceback: The stringified error traceback.
"""
...

View File

@ -30,7 +30,7 @@ from .session_processor_common import SessionProcessorStatus
class DefaultSessionRunner(SessionRunnerBase):
"""Processes a single session's invocations"""
"""Processes a single session's invocations."""
def __init__(
self,
@ -40,6 +40,15 @@ class DefaultSessionRunner(SessionRunnerBase):
on_node_error_callbacks: Optional[list[OnNodeError]] = None,
on_after_run_session_callbacks: Optional[list[OnAfterRunSession]] = None,
):
"""
Args:
on_before_run_session_callbacks: Callbacks to run before the session starts.
on_before_run_node_callbacks: Callbacks to run before each node starts.
on_after_run_node_callbacks: Callbacks to run after each node completes.
on_node_error_callbacks: Callbacks to run when a node errors.
on_after_run_session_callbacks: Callbacks to run after the session completes.
"""
self._on_before_run_session_callbacks = on_before_run_session_callbacks or []
self._on_before_run_node_callbacks = on_before_run_node_callbacks or []
self._on_after_run_node_callbacks = on_after_run_node_callbacks or []
@ -47,14 +56,12 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session_callbacks = on_after_run_session_callbacks or []
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner"""
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
# Exceptions raised outside `run_node` are handled by the processor.
# Exceptions raised outside `run_node` are handled by the processor. There is no need to catch them here.
self._on_before_run_session(queue_item=queue_item)
@ -78,14 +85,16 @@ class DefaultSessionRunner(SessionRunnerBase):
if invocation is None or self._cancel_event.is_set():
break
self.run_node(invocation, queue_item)
# The session is complete if all invocations have been run or there is an error on the session.
if queue_item.session.is_complete() or self._cancel_event.is_set():
break
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
@ -110,7 +119,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_node(invocation, queue_item, output)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
# TODO(psyche): This is expected to be caught in the main thread. Do we need to catch this here?
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
@ -137,6 +146,8 @@ class DefaultSessionRunner(SessionRunnerBase):
)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
"""Run before a session is executed"""
# If profiling is enabled, start the profiler
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
@ -145,6 +156,8 @@ class DefaultSessionRunner(SessionRunnerBase):
callback(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
"""Run after a session is executed"""
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler is not None:
profile_path = self._profiler.stop()
@ -156,7 +169,8 @@ class DefaultSessionRunner(SessionRunnerBase):
# Update the queue item with the completed session
self._services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
# Send complete event
# TODO(psyche): This feels jumbled - we should review separation of concerns here.
# Send complete event. The events service will receive this and update the queue item's status.
self._services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
@ -175,6 +189,7 @@ class DefaultSessionRunner(SessionRunnerBase):
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self._services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
@ -192,6 +207,7 @@ class DefaultSessionRunner(SessionRunnerBase):
self, invocation: BaseInvocation, queue_item: SessionQueueItem, output: BaseInvocationOutput
):
"""Run after a node is executed"""
# Send complete event on successful runs
self._services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
@ -214,6 +230,8 @@ class DefaultSessionRunner(SessionRunnerBase):
error_message: str,
error_traceback: str,
):
"""Run when a node errors"""
# Node errors do not get the full traceback. Only the queue item gets the full traceback.
node_error = f"{error_type}: {error_message}"
queue_item.session.set_node_error(invocation.id, node_error)
@ -356,8 +374,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
resume_event: ThreadEvent,
cancel_event: ThreadEvent,
):
# Outermost processor try block; any unhandled exception is a fatal processor error
try:
# Any unhandled exception in this block is a fatal processor error and will stop the processor.
self._thread_semaphore.acquire()
stop_event.clear()
resume_event.set()
@ -365,8 +383,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
while not stop_event.is_set():
poll_now_event.clear()
# Middle processor try block; any unhandled exception is a non-fatal processor error
try:
# Any unhandled exception in this block is a nonfatal processor error and will be handled.
# If we are paused, wait for resume event
resume_event.wait()