mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(app): make things in session runner private
This commit is contained in:
parent
f7c356d142
commit
cb8e9e1c7b
@ -52,9 +52,9 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
def start(self, services: InvocationServices, cancel_event: ThreadEvent, profiler: Optional[Profiler] = None):
|
||||||
"""Start the session runner"""
|
"""Start the session runner"""
|
||||||
self.services = services
|
self._services = services
|
||||||
self.cancel_event = cancel_event
|
self._cancel_event = cancel_event
|
||||||
self.profiler = profiler
|
self._profiler = profiler
|
||||||
|
|
||||||
def run(self, queue_item: SessionQueueItem):
|
def run(self, queue_item: SessionQueueItem):
|
||||||
"""Run the graph"""
|
"""Run the graph"""
|
||||||
@ -64,33 +64,33 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
invocation = queue_item.session.next()
|
invocation = queue_item.session.next()
|
||||||
if invocation is None or self.cancel_event.is_set():
|
if invocation is None or self._cancel_event.is_set():
|
||||||
break
|
break
|
||||||
self.run_node(invocation, queue_item)
|
self.run_node(invocation, queue_item)
|
||||||
if queue_item.session.is_complete() or self.cancel_event.is_set():
|
if queue_item.session.is_complete() or self._cancel_event.is_set():
|
||||||
break
|
break
|
||||||
|
|
||||||
self._on_after_run_session(queue_item=queue_item)
|
self._on_after_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if self.profiler is not None:
|
if self._profiler is not None:
|
||||||
self.profiler.start(profile_id=queue_item.session_id)
|
self._profiler.start(profile_id=queue_item.session_id)
|
||||||
|
|
||||||
if self.on_before_run_session:
|
if self.on_before_run_session:
|
||||||
self.on_before_run_session(queue_item=queue_item)
|
self.on_before_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_after_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If we are profiling, stop the profiler and dump the profile & stats
|
# If we are profiling, stop the profiler and dump the profile & stats
|
||||||
if self.profiler is not None:
|
if self._profiler is not None:
|
||||||
profile_path = self.profiler.stop()
|
profile_path = self._profiler.stop()
|
||||||
stats_path = profile_path.with_suffix(".json")
|
stats_path = profile_path.with_suffix(".json")
|
||||||
self.services.performance_statistics.dump_stats(
|
self._services.performance_statistics.dump_stats(
|
||||||
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
graph_execution_state_id=queue_item.session.id, output_path=stats_path
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send complete event
|
# Send complete event
|
||||||
self.services.events.emit_graph_execution_complete(
|
self._services.events.emit_graph_execution_complete(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@ -100,8 +100,8 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
|
||||||
# we don't care about that - suppress the error.
|
# we don't care about that - suppress the error.
|
||||||
with suppress(GESStatsNotFoundError):
|
with suppress(GESStatsNotFoundError):
|
||||||
self.services.performance_statistics.log_stats(queue_item.session.id)
|
self._services.performance_statistics.log_stats(queue_item.session.id)
|
||||||
self.services.performance_statistics.reset_stats()
|
self._services.performance_statistics.reset_stats()
|
||||||
|
|
||||||
if self.on_after_run_session:
|
if self.on_after_run_session:
|
||||||
self.on_after_run_session(queue_item)
|
self.on_after_run_session(queue_item)
|
||||||
@ -109,7 +109,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
def _on_before_run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
"""Run before a node is executed"""
|
"""Run before a node is executed"""
|
||||||
# Send starting event
|
# Send starting event
|
||||||
self.services.events.emit_invocation_started(
|
self._services.events.emit_invocation_started(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@ -126,7 +126,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
):
|
):
|
||||||
"""Run after a node is executed"""
|
"""Run after a node is executed"""
|
||||||
# Send complete event on successful runs
|
# Send complete event on successful runs
|
||||||
self.services.events.emit_invocation_complete(
|
self._services.events.emit_invocation_complete(
|
||||||
queue_batch_id=queue_item.batch_id,
|
queue_batch_id=queue_item.batch_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@ -150,13 +150,13 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||||
|
|
||||||
queue_item.session.set_node_error(invocation.id, stacktrace)
|
queue_item.session.set_node_error(invocation.id, stacktrace)
|
||||||
self.services.logger.error(
|
self._services.logger.error(
|
||||||
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
|
f"Error while invoking session {queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{exc_type}"
|
||||||
)
|
)
|
||||||
self.services.logger.error(stacktrace)
|
self._services.logger.error(stacktrace)
|
||||||
|
|
||||||
# Send error event
|
# Send error event
|
||||||
self.services.events.emit_invocation_error(
|
self._services.events.emit_invocation_error(
|
||||||
queue_batch_id=queue_item.session_id,
|
queue_batch_id=queue_item.session_id,
|
||||||
queue_item_id=queue_item.item_id,
|
queue_item_id=queue_item.item_id,
|
||||||
queue_id=queue_item.queue_id,
|
queue_id=queue_item.queue_id,
|
||||||
@ -176,7 +176,7 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
"""Run a single node in the graph"""
|
"""Run a single node in the graph"""
|
||||||
try:
|
try:
|
||||||
# Any unhandled exception is an invocation error & will fail the graph
|
# Any unhandled exception is an invocation error & will fail the graph
|
||||||
with self.services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
self._on_before_run_node(invocation, queue_item)
|
self._on_before_run_node(invocation, queue_item)
|
||||||
|
|
||||||
data = InvocationContextData(
|
data = InvocationContextData(
|
||||||
@ -186,12 +186,12 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
)
|
)
|
||||||
context = build_invocation_context(
|
context = build_invocation_context(
|
||||||
data=data,
|
data=data,
|
||||||
services=self.services,
|
services=self._services,
|
||||||
cancel_event=self.cancel_event,
|
cancel_event=self._cancel_event,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Invoke the node
|
# Invoke the node
|
||||||
outputs = invocation.invoke_internal(context=context, services=self.services)
|
outputs = invocation.invoke_internal(context=context, services=self._services)
|
||||||
# Save outputs and history
|
# Save outputs and history
|
||||||
queue_item.session.complete(invocation.id, outputs)
|
queue_item.session.complete(invocation.id, outputs)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user