mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): rearrange proccessor
This commit is contained in:
parent
1d973f92ff
commit
33f9fe2c86
@ -30,6 +30,8 @@ from .session_processor_common import SessionProcessorStatus
|
||||
|
||||
|
||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||
"""Formats a stacktrace as a string"""
|
||||
|
||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||
|
||||
|
||||
@ -72,6 +74,54 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
|
||||
self._on_after_run_session(queue_item=queue_item)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
try:
|
||||
# Any unhandled exception is an invocation error & will fail the graph
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self._services)
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item, outputs)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||
# If profiling is enabled, start the profiler
|
||||
if self._profiler is not None:
|
||||
@ -172,54 +222,6 @@ class DefaultSessionRunner(SessionRunnerBase):
|
||||
for callback in self._on_node_error_callbacks:
|
||||
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||
"""Run a single node in the graph"""
|
||||
try:
|
||||
# Any unhandled exception is an invocation error & will fail the graph
|
||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||
self._on_before_run_node(invocation, queue_item)
|
||||
|
||||
data = InvocationContextData(
|
||||
invocation=invocation,
|
||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||
queue_item=queue_item,
|
||||
)
|
||||
context = build_invocation_context(
|
||||
data=data,
|
||||
services=self._services,
|
||||
cancel_event=self._cancel_event,
|
||||
)
|
||||
|
||||
# Invoke the node
|
||||
outputs = invocation.invoke_internal(context=context, services=self._services)
|
||||
# Save outputs and history
|
||||
queue_item.session.complete(invocation.id, outputs)
|
||||
|
||||
self._on_after_run_node(invocation, queue_item, outputs)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# TODO(MM2): Create an event for this
|
||||
pass
|
||||
except CanceledException:
|
||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||
# be able to cancel them mid-execution.
|
||||
#
|
||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||
# is executed after each step. This step callback checks if the canceled event is set,
|
||||
# then raises a CanceledException to stop execution immediately.
|
||||
#
|
||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||
pass
|
||||
except Exception as e:
|
||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||
exc_type = type(e)
|
||||
exc_value = e
|
||||
exc_traceback = e.__traceback__
|
||||
assert exc_traceback is not None
|
||||
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||
|
||||
|
||||
class DefaultSessionProcessor(SessionProcessorBase):
|
||||
def __init__(
|
||||
@ -236,24 +238,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
self._thread_limit = thread_limit
|
||||
self._polling_interval = polling_interval
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
) -> None:
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
|
||||
# Cancel the queue item
|
||||
if queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
|
||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self._invoker: Invoker = invoker
|
||||
self._queue_item: Optional[SessionQueueItem] = None
|
||||
@ -396,3 +380,21 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
||||
poll_now_event.clear()
|
||||
self._queue_item = None
|
||||
self._thread_semaphore.release()
|
||||
|
||||
def _on_non_fatal_processor_error(
|
||||
self,
|
||||
queue_item: Optional[SessionQueueItem],
|
||||
exc_type: type,
|
||||
exc_value: BaseException,
|
||||
exc_traceback: TracebackType,
|
||||
) -> None:
|
||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||
# Non-fatal error in processor
|
||||
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
|
||||
# Cancel the queue item
|
||||
if queue_item is not None:
|
||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||
|
||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||
|
Loading…
Reference in New Issue
Block a user