mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): make things in session runner private
This commit is contained in:
parent
f7c356d142
commit
cb8e9e1c7b
@ -52,9 +52,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||
"""Start the session runner"""
|
||||
self.services = services
|
||||
self.cancel_event = cancel_event
|
||||
self.profiler = profiler
|
||||
self._services = services
|
||||
self._cancel_event = cancel_event
|
||||
self._profiler = profiler
|
||||
|
||||
def run(self, queue_item: SessionQueueItem):
|
||||
"""Run the graph"""
|
||||
@ -64,33 +64,33 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
while True:
|
||||
invocation = queue_item.session.next()
|
||||
if invocation is None or self.cancel_event.is_set():
|
||||
if invocation is None or self._cancel_event.is_set():
|
||||
break
|
||||
self.run_node(invocation, queue_item)
|
||||
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
||||
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||
break
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If profiling is enabled, start the profiler
|
||||
if self.profiler is not None:
|
||||
self.profiler.start(profile_id=queue_item.session_id)
|
||||
if self._profiler is not None:
|
||||
self._profiler.start(profile_id=queue_item.session_id)
|
||||
|
||||
if self.on_before_run_session:
|
||||
self.on_before_run_session(queue_item=queue_item)
|
||||
|
||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If we are profiling, stop the profiler and dump the profile & stats
|
||||
if self.profiler is not None:
|
||||
profile_path = self.profiler.stop()
|
||||
if self._profiler is not None:
|
||||
profile_path = self._profiler.stop()
|
||||
stats_path = profile_path.with_suffix(".json")
|
||||
self.services.performance_statistics.dump_stats(
|
||||
self._services.performance_statistics.dump_stats(
|
||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||
)
|
||||
|
||||
# Send complete event
|
||||
self.services.events.emit_graph_execution_complete(
|
||||
self._services.events.emit_graph_execution_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -100,8 +100,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||
# we don't care about that - suppress the error.
|
||||
with suppress(GESStatsNotFoundError):
|
||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self.services.performance_statistics.reset_stats()
|
||||
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||
self._services.performance_statistics.reset_stats()
|
||||
|
||||
if self.on_after_run_session:
|
||||
self.on_after_run_session(queue_item)
|
||||
@ -109,7 +109,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run before a node is executed"""
|
||||
# Send starting event
|
||||
self.services.events.emit_invocation_started(
|
||||
self._services.events.emit_invocation_started(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
):
|
||||
"""Run after a node is executed"""
|
||||
# Send complete event on successful runs
|
||||
self.services.events.emit_invocation_complete(
|
||||
self._services.events.emit_invocation_complete(
|
||||
queue_batch_id=queue_item.batch_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -150,13 +150,13 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
|
||||
queue_item.session.set_node_error(invocation.id, stacktrace)
|
||||
self.services.logger.error(
|
||||
self._services.logger.error(
|
||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
|
||||
)
|
||||
self.services.logger.error(stacktrace)
|
||||
self._services.logger.error(stacktrace)
|
||||
|
||||
# Send error event
|
||||
self.services.events.emit_invocation_error(
|
||||
self._services.events.emit_invocation_error(
|
||||
queue_batch_id=queue_item.session_id,
|
||||
queue_item_id=queue_item.item_id,
|
||||
queue_id=queue_item.queue_id,
|
||||
@ -176,7 +176,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
"""Run a single node in the graph"""
|
||||
try:
|
||||
# Any unhandled exception is an invocation error & will fail the graph
|
||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
|
||||
data = InvocationContextData(
|
||||
@ -186,12 +186,12 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self.services,
|
||||
cancel_event=self.cancel_event,
|
||||
services=self._services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
||||
outputs = invocation.invoke_internal(context=context, services=self._services)
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user