diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index eb868391b8..d6791fbd57 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -122,152 +122,150 @@ class DefaultSessionProcessor(SessionProcessorBase): # Middle processor try block; any unhandled exception is a non-fatal processor error try: # If we are paused, wait for resume event - if resume_event.is_set(): - # Get the next session to process - self._queue_item = self._invoker.services.session_queue.dequeue() + resume_event.wait() - if self._queue_item is not None: - self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") - cancel_event.clear() + # Get the next session to process + self._queue_item = self._invoker.services.session_queue.dequeue() - # If profiling is enabled, start the profiler - if self._profiler is not None: - self._profiler.start(profile_id=self._queue_item.session_id) + if self._queue_item is None: + # The queue was empty, wait for next polling interval or event to try again + self._invoker.services.logger.debug("Waiting for next polling interval or event") + poll_now_event.wait(self._polling_interval) + continue - # Prepare invocations and take the first - self._invocation = self._queue_item.session.next() + self._invoker.services.logger.debug(f"Executing queue item {self._queue_item.item_id}") + cancel_event.clear() - # Loop over invocations until the session is complete or canceled - while self._invocation is not None and not cancel_event.is_set(): - # get the source node id to provide to clients (the prepared node id is not as useful) - source_invocation_id = self._queue_item.session.prepared_source_mapping[ - self._invocation.id - ] + # If profiling is enabled, start the profiler + if self._profiler is not None: + self._profiler.start(profile_id=self._queue_item.session_id) - # Send starting event - self._invoker.services.events.emit_invocation_started( + # Prepare invocations and take the first + self._invocation = self._queue_item.session.next() + + # Loop over invocations until the session is complete or canceled + while self._invocation is not None and not cancel_event.is_set(): + # get the source node id to provide to clients (the prepared node id is not as useful) + source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id] + + # Send starting event + self._invoker.services.events.emit_invocation_started( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session_id, + node=self._invocation.model_dump(), + source_node_id=source_invocation_id, + ) + + # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph + try: + with self._invoker.services.performance_statistics.collect_stats( + self._invocation, self._queue_item.session.id + ): + # Build invocation context (the node-facing API) + data = InvocationContextData( + invocation=self._invocation, + source_invocation_id=source_invocation_id, + queue_item=self._queue_item, + ) + context = build_invocation_context( + data=data, + services=self._invoker.services, + cancel_event=self._cancel_event, + ) + + # Invoke the node + outputs = self._invocation.invoke_internal( + context=context, services=self._invoker.services + ) + + # Save outputs and history + self._queue_item.session.complete(self._invocation.id, outputs) + + # Send complete event + self._invoker.services.events.emit_invocation_complete( queue_batch_id=self._queue_item.batch_id, queue_item_id=self._queue_item.item_id, queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session_id, + graph_execution_state_id=self._queue_item.session.id, node=self._invocation.model_dump(), source_node_id=source_invocation_id, + result=outputs.model_dump(), ) - # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph - try: - with self._invoker.services.performance_statistics.collect_stats( - self._invocation, self._queue_item.session.id - ): - # Build invocation context (the node-facing API) - data = InvocationContextData( - invocation=self._invocation, - source_invocation_id=source_invocation_id, - queue_item=self._queue_item, - ) - context = build_invocation_context( - data=data, - services=self._invoker.services, - cancel_event=self._cancel_event, - ) + except KeyboardInterrupt: + # TODO(MM2): Create an event for this + pass - # Invoke the node - outputs = self._invocation.invoke_internal( - context=context, services=self._invoker.services - ) + except CanceledException: + # When the user cancels the graph, we first set the cancel event. The event is checked + # between invocations, in this loop. Some invocations are long-running, and we need to + # be able to cancel them mid-execution. + # + # For example, denoising is a long-running invocation with many steps. A step callback + # is executed after each step. This step callback checks if the canceled event is set, + # then raises a CanceledException to stop execution immediately. + # + # When we get a CanceledException, we don't need to do anything - just pass and let the + # loop go to its next iteration, and the cancel event will be handled correctly. + pass - # Save outputs and history - self._queue_item.session.complete(self._invocation.id, outputs) + except Exception as e: + error = traceback.format_exc() - # Send complete event - self._invoker.services.events.emit_invocation_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - result=outputs.model_dump(), - ) + # Save error + self._queue_item.session.set_node_error(self._invocation.id, error) + self._invoker.services.logger.error( + f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" + ) + self._invoker.services.logger.error(error) - except KeyboardInterrupt: - # TODO(MM2): Create an event for this - pass + # Send error event + self._invoker.services.events.emit_invocation_error( + queue_batch_id=self._queue_item.session_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + node=self._invocation.model_dump(), + source_node_id=source_invocation_id, + error_type=e.__class__.__name__, + error=error, + ) + pass - except CanceledException: - # When the user cancels the graph, we first set the cancel event. The event is checked - # between invocations, in this loop. Some invocations are long-running, and we need to - # be able to cancel them mid-execution. - # - # For example, denoising is a long-running invocation with many steps. A step callback - # is executed after each step. This step callback checks if the canceled event is set, - # then raises a CanceledException to stop execution immediately. - # - # When we get a CanceledException, we don't need to do anything - just pass and let the - # loop go to its next iteration, and the cancel event will be handled correctly. - pass + # The session is complete if the all invocations are complete or there was an error + if self._queue_item.session.is_complete() or cancel_event.is_set(): + # Send complete event + self._invoker.services.events.emit_graph_execution_complete( + queue_batch_id=self._queue_item.batch_id, + queue_item_id=self._queue_item.item_id, + queue_id=self._queue_item.queue_id, + graph_execution_state_id=self._queue_item.session.id, + ) + # If we are profiling, stop the profiler and dump the profile & stats + if self._profiler: + profile_path = self._profiler.stop() + stats_path = profile_path.with_suffix(".json") + self._invoker.services.performance_statistics.dump_stats( + graph_execution_state_id=self._queue_item.session.id, output_path=stats_path + ) + # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor + # we don't care about that - suppress the error. + with suppress(GESStatsNotFoundError): + self._invoker.services.performance_statistics.log_stats(self._queue_item.session.id) + self._invoker.services.performance_statistics.reset_stats() - except Exception as e: - error = traceback.format_exc() - - # Save error - self._queue_item.session.set_node_error(self._invocation.id, error) - self._invoker.services.logger.error( - f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}" - ) - self._invoker.services.logger.error(error) - - # Send error event - self._invoker.services.events.emit_invocation_error( - queue_batch_id=self._queue_item.session_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - node=self._invocation.model_dump(), - source_node_id=source_invocation_id, - error_type=e.__class__.__name__, - error=error, - ) - pass - - # The session is complete if the all invocations are complete or there was an error - if self._queue_item.session.is_complete() or cancel_event.is_set(): - # Send complete event - self._invoker.services.events.emit_graph_execution_complete( - queue_batch_id=self._queue_item.batch_id, - queue_item_id=self._queue_item.item_id, - queue_id=self._queue_item.queue_id, - graph_execution_state_id=self._queue_item.session.id, - ) - # If we are profiling, stop the profiler and dump the profile & stats - if self._profiler: - profile_path = self._profiler.stop() - stats_path = profile_path.with_suffix(".json") - self._invoker.services.performance_statistics.dump_stats( - graph_execution_state_id=self._queue_item.session.id, output_path=stats_path - ) - # We'll get a GESStatsNotFoundError if we try to log stats for an untracked graph, but in the processor - # we don't care about that - suppress the error. - with suppress(GESStatsNotFoundError): - self._invoker.services.performance_statistics.log_stats( - self._queue_item.session.id - ) - self._invoker.services.performance_statistics.reset_stats() - - # Set the invocation to None to prepare for the next session - self._invocation = None - else: - # Prepare the next invocation - self._invocation = self._queue_item.session.next() - - # The session is complete, immediately poll for next session - self._queue_item = None - poll_now_event.set() + # Set the invocation to None to prepare for the next session + self._invocation = None else: - # The queue was empty, wait for next polling interval or event to try again - self._invoker.services.logger.debug("Waiting for next polling interval or event") - poll_now_event.wait(self._polling_interval) - continue + # Prepare the next invocation + self._invocation = self._queue_item.session.next() + else: + # The queue was empty, wait for next polling interval or event to try again + self._invoker.services.logger.debug("Waiting for next polling interval or event") + poll_now_event.wait(self._polling_interval) + continue except Exception: # Non-fatal error in processor self._invoker.services.logger.error(