feat(app): make things in session runner private

This commit is contained in:
psychedelicious 2024-05-22 18:55:33 +10:00
parent f7c356d142
commit cb8e9e1c7b

View File

@ -52,9 +52,9 @@ class DefaultSessionRunner(SessionRunnerBase):
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None): def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner""" """Start the session runner"""
self.services = services self._services = services
self.cancel_event = cancel_event self._cancel_event = cancel_event
self.profiler = profiler self._profiler = profiler
def run(self, queue_item: SessionQueueItem): def run(self, queue_item: SessionQueueItem):
"""Run the graph""" """Run the graph"""
@ -64,33 +64,33 @@ class DefaultSessionRunner(SessionRunnerBase):
while True: while True:
invocation = queue_item.session.next() invocation = queue_item.session.next()
if invocation is None or self.cancel_event.is_set(): if invocation is None or self._cancel_event.is_set():
break break
self.run_node(invocation, queue_item) self.run_node(invocation, queue_item)
if queue_item.session.is_complete() or self.cancel_event.is_set(): if queue_item.session.is_complete() or self._cancel_event.is_set():
break break
self._on_after_run_session(queue_item=queue_item) self._on_after_run_session(queue_item=queue_item)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None: def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler # If profiling is enabled, start the profiler
if self.profiler is not None: if self._profiler is not None:
self.profiler.start(profile_id=queue_item.session_id) self._profiler.start(profile_id=queue_item.session_id)
if self.on_before_run_session: if self.on_before_run_session:
self.on_before_run_session(queue_item=queue_item) self.on_before_run_session(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None: def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# If we are profiling, stop the profiler and dump the profile & stats # If we are profiling, stop the profiler and dump the profile & stats
if self.profiler is not None: if self._profiler is not None:
profile_path = self.profiler.stop() profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json") stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats( self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path graph_execution_state_id=queue_item.session.id, output_path=stats_path
) )
# Send complete event # Send complete event
self.services.events.emit_graph_execution_complete( self._services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id, queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -100,8 +100,8 @@ class DefaultSessionRunner(SessionRunnerBase):
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error. # we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError): with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id) self._services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats() self._services.performance_statistics.reset_stats()
if self.on_after_run_session: if self.on_after_run_session:
self.on_after_run_session(queue_item) self.on_after_run_session(queue_item)
@ -109,7 +109,7 @@ class DefaultSessionRunner(SessionRunnerBase):
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem): def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed""" """Run before a node is executed"""
# Send starting event # Send starting event
self.services.events.emit_invocation_started( self._services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id, queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
): ):
"""Run after a node is executed""" """Run after a node is executed"""
# Send complete event on successful runs # Send complete event on successful runs
self.services.events.emit_invocation_complete( self._services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id, queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -150,13 +150,13 @@ class DefaultSessionRunner(SessionRunnerBase):
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback) stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
queue_item.session.set_node_error(invocation.id, stacktrace) queue_item.session.set_node_error(invocation.id, stacktrace)
self.services.logger.error( self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}" f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
) )
self.services.logger.error(stacktrace) self._services.logger.error(stacktrace)
# Send error event # Send error event
self.services.events.emit_invocation_error( self._services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id, queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id, queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id, queue_id=queue_item.queue_id,
@ -176,7 +176,7 @@ class DefaultSessionRunner(SessionRunnerBase):
"""Run a single node in the graph""" """Run a single node in the graph"""
try: try:
# Any unhandled exception is an invocation error & will fail the graph # Any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id): with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item) self._on_before_run_node(invocation, queue_item)
data = InvocationContextData( data = InvocationContextData(
@ -186,12 +186,12 @@ class DefaultSessionRunner(SessionRunnerBase):
) )
context = build_invocation_context( context = build_invocation_context(
data=data, data=data,
services=self.services, services=self._services,
cancel_event=self.cancel_event, cancel_event=self._cancel_event,
) )
# Invoke the node # Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services) outputs = invocation.invoke_internal(context=context, services=self._services)
# Save outputs and history # Save outputs and history
queue_item.session.complete(invocation.id, outputs) queue_item.session.complete(invocation.id, outputs)