mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): rearrange proccessor
This commit is contained in:
parent
1d973f92ff
commit
33f9fe2c86
@ -30,6 +30,8 @@ from .session_processor_common import SessionProcessorStatus
|
|||||||
|
|
||||||
|
|
||||||
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
def get_stacktrace(exc_type: type, exc_value: BaseException, exc_traceback: TracebackType) -> str:
|
||||||
|
"""Formats a stacktrace as a string"""
|
||||||
|
|
||||||
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
return "".join(traceback.format_exception(exc_type, exc_value, exc_traceback))
|
||||||
|
|
||||||
|
|
||||||
@ -72,6 +74,54 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
|
|
||||||
self._on_after_run_session(queue_item=queue_item)
|
self._on_after_run_session(queue_item=queue_item)
|
||||||
|
|
||||||
|
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
||||||
|
"""Run a single node in the graph"""
|
||||||
|
try:
|
||||||
|
# Any unhandled exception is an invocation error & will fail the graph
|
||||||
|
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
||||||
|
self._on_before_run_node(invocation, queue_item)
|
||||||
|
|
||||||
|
data = InvocationContextData(
|
||||||
|
invocation=invocation,
|
||||||
|
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
||||||
|
queue_item=queue_item,
|
||||||
|
)
|
||||||
|
context = build_invocation_context(
|
||||||
|
data=data,
|
||||||
|
services=self._services,
|
||||||
|
cancel_event=self._cancel_event,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invoke the node
|
||||||
|
outputs = invocation.invoke_internal(context=context, services=self._services)
|
||||||
|
# Save outputs and history
|
||||||
|
queue_item.session.complete(invocation.id, outputs)
|
||||||
|
|
||||||
|
self._on_after_run_node(invocation, queue_item, outputs)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# TODO(MM2): Create an event for this
|
||||||
|
pass
|
||||||
|
except CanceledException:
|
||||||
|
# When the user cancels the graph, we first set the cancel event. The event is checked
|
||||||
|
# between invocations, in this loop. Some invocations are long-running, and we need to
|
||||||
|
# be able to cancel them mid-execution.
|
||||||
|
#
|
||||||
|
# For example, denoising is a long-running invocation with many steps. A step callback
|
||||||
|
# is executed after each step. This step callback checks if the canceled event is set,
|
||||||
|
# then raises a CanceledException to stop execution immediately.
|
||||||
|
#
|
||||||
|
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
||||||
|
# loop go to its next iteration, and the cancel event will be handled correctly.
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
||||||
|
exc_type = type(e)
|
||||||
|
exc_value = e
|
||||||
|
exc_traceback = e.__traceback__
|
||||||
|
assert exc_traceback is not None
|
||||||
|
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||||
|
|
||||||
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
def _on_before_run_session(self, queue_item: SessionQueueItem) -> None:
|
||||||
# If profiling is enabled, start the profiler
|
# If profiling is enabled, start the profiler
|
||||||
if self._profiler is not None:
|
if self._profiler is not None:
|
||||||
@ -172,54 +222,6 @@ class DefaultSessionRunner(SessionRunnerBase):
|
|||||||
for callback in self._on_node_error_callbacks:
|
for callback in self._on_node_error_callbacks:
|
||||||
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
callback(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
||||||
|
|
||||||
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
|
|
||||||
"""Run a single node in the graph"""
|
|
||||||
try:
|
|
||||||
# Any unhandled exception is an invocation error & will fail the graph
|
|
||||||
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
|
|
||||||
self._on_before_run_node(invocation, queue_item)
|
|
||||||
|
|
||||||
data = InvocationContextData(
|
|
||||||
invocation=invocation,
|
|
||||||
source_invocation_id=queue_item.session.prepared_source_mapping[invocation.id],
|
|
||||||
queue_item=queue_item,
|
|
||||||
)
|
|
||||||
context = build_invocation_context(
|
|
||||||
data=data,
|
|
||||||
services=self._services,
|
|
||||||
cancel_event=self._cancel_event,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Invoke the node
|
|
||||||
outputs = invocation.invoke_internal(context=context, services=self._services)
|
|
||||||
# Save outputs and history
|
|
||||||
queue_item.session.complete(invocation.id, outputs)
|
|
||||||
|
|
||||||
self._on_after_run_node(invocation, queue_item, outputs)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
# TODO(MM2): Create an event for this
|
|
||||||
pass
|
|
||||||
except CanceledException:
|
|
||||||
# When the user cancels the graph, we first set the cancel event. The event is checked
|
|
||||||
# between invocations, in this loop. Some invocations are long-running, and we need to
|
|
||||||
# be able to cancel them mid-execution.
|
|
||||||
#
|
|
||||||
# For example, denoising is a long-running invocation with many steps. A step callback
|
|
||||||
# is executed after each step. This step callback checks if the canceled event is set,
|
|
||||||
# then raises a CanceledException to stop execution immediately.
|
|
||||||
#
|
|
||||||
# When we get a CanceledException, we don't need to do anything - just pass and let the
|
|
||||||
# loop go to its next iteration, and the cancel event will be handled correctly.
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
# Must extract the exception traceback here to not lose its stacktrace when we change scope
|
|
||||||
exc_type = type(e)
|
|
||||||
exc_value = e
|
|
||||||
exc_traceback = e.__traceback__
|
|
||||||
assert exc_traceback is not None
|
|
||||||
self._on_node_error(invocation, queue_item, exc_type, exc_value, exc_traceback)
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultSessionProcessor(SessionProcessorBase):
|
class DefaultSessionProcessor(SessionProcessorBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -236,24 +238,6 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
self._thread_limit = thread_limit
|
self._thread_limit = thread_limit
|
||||||
self._polling_interval = polling_interval
|
self._polling_interval = polling_interval
|
||||||
|
|
||||||
def _on_non_fatal_processor_error(
|
|
||||||
self,
|
|
||||||
queue_item: Optional[SessionQueueItem],
|
|
||||||
exc_type: type,
|
|
||||||
exc_value: BaseException,
|
|
||||||
exc_traceback: TracebackType,
|
|
||||||
) -> None:
|
|
||||||
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
|
||||||
# Non-fatal error in processor
|
|
||||||
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
|
|
||||||
# Cancel the queue item
|
|
||||||
if queue_item is not None:
|
|
||||||
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
|
||||||
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
|
||||||
|
|
||||||
for callback in self._on_non_fatal_processor_error_callbacks:
|
|
||||||
callback(exc_type, exc_value, exc_traceback, queue_item)
|
|
||||||
|
|
||||||
def start(self, invoker: Invoker) -> None:
|
def start(self, invoker: Invoker) -> None:
|
||||||
self._invoker: Invoker = invoker
|
self._invoker: Invoker = invoker
|
||||||
self._queue_item: Optional[SessionQueueItem] = None
|
self._queue_item: Optional[SessionQueueItem] = None
|
||||||
@ -396,3 +380,21 @@ class DefaultSessionProcessor(SessionProcessorBase):
|
|||||||
poll_now_event.clear()
|
poll_now_event.clear()
|
||||||
self._queue_item = None
|
self._queue_item = None
|
||||||
self._thread_semaphore.release()
|
self._thread_semaphore.release()
|
||||||
|
|
||||||
|
def _on_non_fatal_processor_error(
|
||||||
|
self,
|
||||||
|
queue_item: Optional[SessionQueueItem],
|
||||||
|
exc_type: type,
|
||||||
|
exc_value: BaseException,
|
||||||
|
exc_traceback: TracebackType,
|
||||||
|
) -> None:
|
||||||
|
stacktrace = get_stacktrace(exc_type, exc_value, exc_traceback)
|
||||||
|
# Non-fatal error in processor
|
||||||
|
self._invoker.services.logger.error(f"Non-fatal error in session processor:\n{stacktrace}")
|
||||||
|
# Cancel the queue item
|
||||||
|
if queue_item is not None:
|
||||||
|
self._invoker.services.session_queue.set_queue_item_session(queue_item.item_id, queue_item.session)
|
||||||
|
self._invoker.services.session_queue.cancel_queue_item(queue_item.item_id, error=stacktrace)
|
||||||
|
|
||||||
|
for callback in self._on_non_fatal_processor_error_callbacks:
|
||||||
|
callback(exc_type, exc_value, exc_traceback, queue_item)
|
||||||
|
Loading…
Reference in New Issue
Block a user