diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py
index 7d761e627f..9ba726bff3 100644
--- a/invokeai/app/services/session_processor/session_processor_default.py
+++ b/invokeai/app/services/session_processor/session_processor_default.py
@@ -7,6 +7,7 @@ from typing import Optional
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event as FastAPIEvent
 
+from invokeai.app.invocations.baseinvocation import BaseInvocation
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.services.invocation_stats.invocation_stats_common import GESStatsNotFoundError
 from invokeai.app.services.session_processor.session_processor_common import CanceledException
@@ -23,6 +24,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
     def start(self, invoker: Invoker, thread_limit: int = 1, polling_interval: int = 1) -> None:
         self._invoker: Invoker = invoker
         self._queue_item: Optional[SessionQueueItem] = None
+        self._invocation: Optional[BaseInvocation] = None
 
         self._resume_event = ThreadEvent()
         self._stop_event = ThreadEvent()
@@ -134,12 +136,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
                             profiler.start(profile_id=self._queue_item.session_id)
 
                         # Prepare invocations and take the first
-                        invocation = self._queue_item.session.next()
+                        self._invocation = self._queue_item.session.next()
 
                         # Loop over invocations until the session is complete or canceled
-                        while invocation is not None and not cancel_event.is_set():
+                        while self._invocation is not None and not cancel_event.is_set():
                             # get the source node id to provide to clients (the prepared node id is not as useful)
-                            source_invocation_id = self._queue_item.session.prepared_source_mapping[invocation.id]
+                            source_invocation_id = self._queue_item.session.prepared_source_mapping[self._invocation.id]
 
                             # Send starting event
                             self._invoker.services.events.emit_invocation_started(
@@ -147,18 +149,18 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                 queue_item_id=self._queue_item.item_id,
                                 queue_id=self._queue_item.queue_id,
                                 graph_execution_state_id=self._queue_item.session_id,
-                                node=invocation.model_dump(),
+                                node=self._invocation.model_dump(),
                                 source_node_id=source_invocation_id,
                             )
 
                             # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
                             try:
                                 with self._invoker.services.performance_statistics.collect_stats(
-                                    invocation, self._queue_item.session.id
+                                    self._invocation, self._queue_item.session.id
                                 ):
                                     # Build invocation context (the node-facing API)
                                     data = InvocationContextData(
-                                        invocation=invocation,
+                                        invocation=self._invocation,
                                         source_invocation_id=source_invocation_id,
                                         queue_item=self._queue_item,
                                     )
@@ -169,12 +171,12 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                     )
 
                                     # Invoke the node
-                                    outputs = invocation.invoke_internal(
+                                    outputs = self._invocation.invoke_internal(
                                         context=context, services=self._invoker.services
                                     )
 
                                     # Save outputs and history
-                                    self._queue_item.session.complete(invocation.id, outputs)
+                                    self._queue_item.session.complete(self._invocation.id, outputs)
 
                                     # Send complete event
                                     self._invoker.services.events.emit_invocation_complete(
@@ -182,7 +184,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                         queue_item_id=self._queue_item.item_id,
                                         queue_id=self._queue_item.queue_id,
                                         graph_execution_state_id=self._queue_item.session.id,
-                                        node=invocation.model_dump(),
+                                        node=self._invocation.model_dump(),
                                         source_node_id=source_invocation_id,
                                         result=outputs.model_dump(),
                                     )
@@ -208,9 +210,9 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                 error = traceback.format_exc()
 
                                 # Save error
-                                self._queue_item.session.set_node_error(invocation.id, error)
+                                self._queue_item.session.set_node_error(self._invocation.id, error)
                                 self._invoker.services.logger.error(
-                                    f"Error while invoking session {self._queue_item.session_id}, invocation {invocation.id} ({invocation.get_type()}):\n{e}"
+                                    f"Error while invoking session {self._queue_item.session_id}, invocation {self._invocation.id} ({self._invocation.get_type()}):\n{e}"
                                 )
 
                                 # Send error event
@@ -219,7 +221,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                     queue_item_id=self._queue_item.item_id,
                                     queue_id=self._queue_item.queue_id,
                                     graph_execution_state_id=self._queue_item.session.id,
-                                    node=invocation.model_dump(),
+                                    node=self._invocation.model_dump(),
                                     source_node_id=source_invocation_id,
                                     error_type=e.__class__.__name__,
                                     error=error,
@@ -236,10 +238,10 @@ class DefaultSessionProcessor(SessionProcessorBase):
                                 )
                                 # Save the stats and stop the profiler if it's running
                                 stats_cleanup(self._queue_item.session.id)
-                                invocation = None
+                                self._invocation = None
                             else:
                                 # Prepare the next invocation
-                                invocation = self._queue_item.session.next()
+                                self._invocation = self._queue_item.session.next()
 
                         # The session is complete, immediately poll for next session
                         self._queue_item = None