Merge branch 'psyche/fix/nodes/processor-cpu-usage' into lstein/feat/multi-gpu

This commit is contained in:
Lincoln Stein 2024-03-31 17:05:23 -04:00
commit cef51ad80d

View File

@ -122,152 +122,150 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Middle processor try block; any unhandled exception is a non-fatal processor error # Middle processor try block; any unhandled exception is a non-fatal processor error
try: try:
# If we are paused, wait for resume event # If we are paused, wait for resume event
if resume_event.is_set(): resume_event.wait()
# Get the next session to process
self._queue_item = self._invoker.services.session_queue.dequeue()
if self._queue_item is not None: # Get the next session to process
self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") self._queue_item = self._invoker.services.session_queue.dequeue()
cancel_event.clear()
# If profiling is enabled, start the profiler if self._queue_item is None:
if self._profiler is not None: # The queue was empty, wait for next polling interval or event to try again
self._profiler.start(profile_id=self._queue_item.session_id) self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
# Prepare invocations and take the first self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}")
self._invocation = self._queue_item.session.next() cancel_event.clear()
# Loop over invocations until the session is complete or canceled # If profiling is enabled, start the profiler
while self._invocation is not None and not cancel_event.is_set(): if self._profiler is not None:
# get the source node id to provide to clients (the prepared node id is not as useful) self._profiler.start(profile_id=self._queue_item.session_id)
source_invocation_id = self._queue_item.session.prepared_source_mapping[
self._invocation.id
]
# Send starting event # Prepare invocations and take the first
self._invoker.services.events.emit_invocation_started( self._invocation = self._queue_item.session.next()
# Loop over invocations until the session is complete or canceled
while self._invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful)
source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
# Send starting event
self._invoker.services.events.emit_invocation_started(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
)
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
try:
with self._invoker.services.performance_statistics.collect_stats(
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node
outputs = self._invocation.invoke_internal(
context=context, services=self._invoker.services
)
# Save outputs and history
self._queue_item.session.complete(self._invocation.id, outputs)
# Send complete event
self._invoker.services.events.emit_invocation_complete(
queue_batch_id=self._queue_item.batch_id, queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id, queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id, queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id, graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(), node=self._invocation.model_dump(),
source_node_id=source_invocation_id, source_node_id=source_invocation_id,
result=outputs.model_dump(),
) )
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph except KeyboardInterrupt:
try: # TODO(MM2): Create an event for this
with self._invoker.services.performance_statistics.collect_stats( pass
self._invocation, self._queue_item.session.id
):
# Build invocation context (the node-facing API)
data = InvocationContextData(
invocation=self._invocation,
source_invocation_id=source_invocation_id,
queue_item=self._queue_item,
)
context = build_invocation_context(
data=data,
services=self._invoker.services,
cancel_event=self._cancel_event,
)
# Invoke the node except CanceledException:
outputs = self._invocation.invoke_internal( # When the user cancels the graph, we first set the cancel event. The event is checked
context=context, services=self._invoker.services # between invocations, in this loop. Some invocations are long-running, and we need to
) # be able to cancel them mid-execution.
#
# For example, denoising is a long-running invocation with many steps. A step callback
# is executed after each step. This step callback checks if the canceled event is set,
# then raises a CanceledException to stop execution immediately.
#
# When we get a CanceledException, we don't need to do anything - just pass and let the
# loop go to its next iteration, and the cancel event will be handled correctly.
pass
# Save outputs and history except Exception as e:
self._queue_item.session.complete(self._invocation.id, outputs) error = traceback.format_exc()
# Send complete event # Save error
self._invoker.services.events.emit_invocation_complete( self._queue_item.session.set_node_error(self._invocation.id, error)
queue_batch_id=self._queue_item.batch_id, self._invoker.services.logger.error(
queue_item_id=self._queue_item.item_id, f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
queue_id=self._queue_item.queue_id, )
graph_execution_state_id=self._queue_item.session.id, self._invoker.services.logger.error(error)
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
result=outputs.model_dump(),
)
except KeyboardInterrupt: # Send error event
# TODO(MM2): Create an event for this self._invoker.services.events.emit_invocation_error(
pass queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
except CanceledException: # The session is complete if all invocations are complete or there was an error
# When the user cancels the graph, we first set the cancel event. The event is checked if self._queue_item.session.is_complete() or cancel_event.is_set():
# between invocations, in this loop. Some invocations are long-running, and we need to # Send complete event
# be able to cancel them mid-execution. self._invoker.services.events.emit_graph_execution_complete(
# queue_batch_id=self._queue_item.batch_id,
# For example, denoising is a long-running invocation with many steps. A step callback queue_item_id=self._queue_item.item_id,
# is executed after each step. This step callback checks if the canceled event is set, queue_id=self._queue_item.queue_id,
# then raises a CanceledException to stop execution immediately. graph_execution_state_id=self._queue_item.session.id,
# )
# When we get a CanceledException, we don't need to do anything - just pass and let the # If we are profiling, stop the profiler and dump the profile & stats
# loop go to its next iteration, and the cancel event will be handled correctly. if self._profiler:
pass profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id)
self._invoker.services.performance_statistics.reset_stats()
except Exception as e: # Set the invocation to None to prepare for the next session
error = traceback.format_exc() self._invocation = None
# Save error
self._queue_item.session.set_node_error(self._invocation.id, error)
self._invoker.services.logger.error(
f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
)
self._invoker.services.logger.error(error)
# Send error event
self._invoker.services.events.emit_invocation_error(
queue_batch_id=self._queue_item.session_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
node=self._invocation.model_dump(),
source_node_id=source_invocation_id,
error_type=e.__class__.__name__,
error=error,
)
pass
# The session is complete if all invocations are complete or there was an error
if self._queue_item.session.is_complete() or cancel_event.is_set():
# Send complete event
self._invoker.services.events.emit_graph_execution_complete(
queue_batch_id=self._queue_item.batch_id,
queue_item_id=self._queue_item.item_id,
queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id,
)
# If we are profiling, stop the profiler and dump the profile & stats
if self._profiler:
profile_path = self._profiler.stop()
stats_path = profile_path.with_suffix(".json")
self._invoker.services.performance_statistics.dump_stats(
graph_execution_state_id=self._queue_item.session.id, output_path=stats_path
)
# We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor
# we don't care about that - suppress the error.
with suppress(GESStatsNotFoundError):
self._invoker.services.performance_statistics.log_stats(
self._queue_item.session.id
)
self._invoker.services.performance_statistics.reset_stats()
# Set the invocation to None to prepare for the next session
self._invocation = None
else:
# Prepare the next invocation
self._invocation = self._queue_item.session.next()
# The session is complete, immediately poll for next session
self._queue_item = None
poll_now_event.set()
else: else:
# The queue was empty, wait for next polling interval or event to try again # Prepare the next invocation
self._invoker.services.logger.debug("Waiting for next polling interval or event") self._invocation = self._queue_item.session.next()
poll_now_event.wait(self._polling_interval) else:
continue # The queue was empty, wait for next polling interval or event to try again
self._invoker.services.logger.debug("Waiting for next polling interval or event")
poll_now_event.wait(self._polling_interval)
continue
except Exception: except Exception:
# Non-fatal error in processor # Non-fatal error in processor
self._invoker.services.logger.error( self._invoker.services.logger.error(