feat(app): make things in session runner private

This commit is contained in:
psychedelicious 2024-05-22 18:55:33 +10:00
parent f7c356d142
commit cb8e9e1c7b

View File

@ -52,9 +52,9 @@ class DefaultSessionRunner(SessionRunnerBase):
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
"""Start the session runner"""
self.services = services
self.cancel_event = cancel_event
self.profiler = profiler
self._services = services
self._cancel_event = cancel_event
self._profiler = profiler
def run(self, queue_item: SessionQueueItem):
"""Run the graph"""
@ -64,33 +64,33 @@ class DefaultSessionRunner(SessionRunnerBase):
while True:
invocation = queue_item.session.next()
if invocation is None or self.cancel_event.is_set():
if invocation is None or self._cancel_event.is_set():
break
self.run_node(invocation, queue_item)
if queue_item.session.is_complete() or self.cancel_event.is_set():
if queue_item.session.is_complete() or self._cancel_event.is_set():
break
self._on_after_run_session(queue_item=queue_item)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler
if self.profiler is not None:
self.profiler.start(profile_id=queue_item.session_id)
if self._profiler is not None:
self._profiler.start(profile_id=queue_item.session_id)
if self.on_before_run_session:
self.on_before_run_session(queue_item=queue_item)
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
# If we are profiling, stop the profiler and dump the profile & stats
if self.profiler is not None:
profile_path = self.profiler.stop()
if self._profiler is not None:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self.services.performance_statistics.dump_stats(
self._services.performance_statistics.dump_stats(
graph_execution_state_id=queue_item.session.id, output_path=stats_path
)
# Send complete event
self.services.events.emit_graph_execution_complete(
self._services.events.emit_graph_execution_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@ -100,8 +100,8 @@ class DefaultSessionRunner(SessionRunnerBase):
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self.services.performance_statistics.log_stats(queue_item.session.id)
self.services.performance_statistics.reset_stats()
self._services.performance_statistics.log_stats(queue_item.session.id)
self._services.performance_statistics.reset_stats()
if self.on_after_run_session:
self.on_after_run_session(queue_item)
@ -109,7 +109,7 @@ class DefaultSessionRunner(SessionRunnerBase):
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run before a node is executed"""
# Send starting event
self.services.events.emit_invocation_started(
self._services.events.emit_invocation_started(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
):
"""Run after a node is executed"""
# Send complete event on successful runs
self.services.events.emit_invocation_complete(
self._services.events.emit_invocation_complete(
queue_batch_id=queue_item.batch_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@ -150,13 +150,13 @@ class DefaultSessionRunner(SessionRunnerBase):
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
queue_item.session.set_node_error(invocation.id, stacktrace)
self.services.logger.error(
self._services.logger.error(
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
)
self.services.logger.error(stacktrace)
self._services.logger.error(stacktrace)
# Send error event
self.services.events.emit_invocation_error(
self._services.events.emit_invocation_error(
queue_batch_id=queue_item.session_id,
queue_item_id=queue_item.item_id,
queue_id=queue_item.queue_id,
@ -176,7 +176,7 @@ class DefaultSessionRunner(SessionRunnerBase):
"""Run a single node in the graph"""
try:
# Any unhandled exception is an invocation error & will fail the graph
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
@ -186,12 +186,12 @@ class DefaultSessionRunner(SessionRunnerBase):
)
context = build_invocation_context(
data=data,
services=self.services,
cancel_event=self.cancel_event,
services=self._services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self.services)
outputs = invocation.invoke_internal(context=context, services=self._services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)