tidy(app): rearrange proccessor

This commit is contained in:
psychedelicious 2024-05-22 19:07:57 +10:00
parent 1d973f92ff
commit 33f9fe2c86

View File

@ -30,6 +30,8 @@ from .session_processor_common import SessionProcessorStatus
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
"""Formats a stacktrace as a string"""
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
@ -72,6 +74,54 @@ class DefaultSessionRunner(SessionRunnerBase):
self._on_after_run_session(queue_item=queue_item)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
try:
# Any unhandled exception is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self._services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item, outputs)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
# If profiling is enabled, start the profiler
if self._profiler is not None:
@ -172,54 +222,6 @@ class DefaultSessionRunner(SessionRunnerBase):
for callback in self._on_node_error_callbacks:
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
"""Run a single node in the graph"""
try:
# Any unhandled exception is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
self._on_before_run_node(invocation, queue_item)
data = InvocationContextData(
invocation=invocation,
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
queue_item=queue_item,
)
context = build_invocation_context(
data=data,
services=self._services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = invocation.invoke_internal(context=context, services=self._services)
# Save outputs and history
queue_item.session.complete(invocation.id, outputs)
self._on_after_run_node(invocation, queue_item, outputs)
except KeyboardInterrupt:
# TODO(MM2): Create an event for this
pass
except CanceledException:
# When the user cancels the graph, we first set the cancel event. The event is checked
# between invocations, in this loop. Some invocations are long-running, and we need to
# be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
except Exception as e:
# Must extract the exception traceback here to not lose its stacktrace when we change scope
exc_type = type(e)
exc_value = e
exc_traceback = e.__traceback__
assert exc_traceback is not None
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
class DefaultSessionProcessor(SessionProcessorBase):
def __init__(
@ -236,24 +238,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
self._thread_limit = thread_limit
self._polling_interval = polling_interval
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
) -> None:
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
# Cancel the queue item
if queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(exc_type, exc_value, exc_traceback, queue_item)
def start(self, invoker: Invoker) -> None:
self._invoker: Invoker = invoker
self._queue_item: Optional[SessionQueueItem] = None
@ -396,3 +380,21 @@ class DefaultSessionProcessor(SessionProcessorBase):
poll_now_event.clear()
self._queue_item = None
self._thread_semaphore.release()
def _on_non_fatal_processor_error(
self,
queue_item: Optional[SessionQueueItem],
exc_type: type,
exc_value: BaseException,
exc_traceback: TracebackType,
) -> None:
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
# Non-fatal error in processor
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
# Cancel the queue item
if queue_item is not None:
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
for callback in self._on_non_fatal_processor_error_callbacks:
callback(exc_type, exc_value, exc_traceback, queue_item)