diff --git a/invokeai/app/services/session_processor/session_processor_default.py b/invokeai/app/services/session_processor/session_processor_default.py index e49f79bcf3..dc08fc8345 100644 --- a/invokeai/app/services/session_processor/session_processor_default.py +++ b/invokeai/app/services/session_processor/session_processor_default.py @@ -157,13 +157,13 @@ class DefaultSessionProcessor(SessionProcessorBase): invocation, self._queue_item.session.id ): # Build invocation context (the node-facing API) - context_data = InvocationContextData( + data = InvocationContextData( invocation=invocation, source_invocation_id=source_invocation_id, queue_item=self._queue_item, ) context = build_invocation_context( - context_data=context_data, + data=data, services=self._invoker.services, cancel_event=self._cancel_event, ) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 994c99dc45..f8425523bf 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -58,9 +58,9 @@ class InvocationContextData: class InvocationContextInterface: - def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None: + def __init__(self, services: InvocationServices, data: InvocationContextData) -> None: self._services = services - self._context_data = context_data + self._data = data class BoardsInterface(InvocationContextInterface): @@ -166,26 +166,26 @@ class ImagesInterface(InvocationContextInterface): metadata_ = None if metadata: metadata_ = metadata - elif isinstance(self._context_data.invocation, WithMetadata): - metadata_ = self._context_data.invocation.metadata + elif isinstance(self._data.invocation, WithMetadata): + metadata_ = self._data.invocation.metadata # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. board_id_ = None if board_id: board_id_ = board_id - elif isinstance(self._context_data.invocation, WithBoard) and self._context_data.invocation.board: - board_id_ = self._context_data.invocation.board.board_id + elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board: + board_id_ = self._data.invocation.board.board_id return self._services.images.create( image=image, - is_intermediate=self._context_data.invocation.is_intermediate, + is_intermediate=self._data.invocation.is_intermediate, image_category=image_category, board_id=board_id_, metadata=metadata_, image_origin=ResourceOrigin.INTERNAL, - workflow=self._context_data.queue_item.workflow, - session_id=self._context_data.queue_item.session_id, - node_id=self._context_data.invocation.id, + workflow=self._data.queue_item.workflow, + session_id=self._data.queue_item.session_id, + node_id=self._data.invocation.id, ) def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image: @@ -285,7 +285,7 @@ class ModelsInterface(InvocationContextInterface): # the event payloads. return self._services.model_manager.load_model_by_key( - key=key, submodel_type=submodel_type, context_data=self._context_data + key=key, submodel_type=submodel_type, context_data=self._data ) def load_by_attrs( @@ -304,7 +304,7 @@ class ModelsInterface(InvocationContextInterface): base_model=base_model, model_type=model_type, submodel=submodel, - context_data=self._context_data, + context_data=self._data, ) def get_config(self, key: str) -> AnyModelConfig: @@ -364,9 +364,9 @@ class ConfigInterface(InvocationContextInterface): class UtilInterface(InvocationContextInterface): def __init__( - self, services: InvocationServices, context_data: InvocationContextData, cancel_event: threading.Event + self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event ) -> None: - super().__init__(services, context_data) + super().__init__(services, data) self._cancel_event = cancel_event def is_canceled(self) -> bool: @@ -385,7 +385,7 @@ class UtilInterface(InvocationContextInterface): """ stable_diffusion_step_callback( - context_data=self._context_data, + context_data=self._data, intermediate_state=intermediate_state, base_model=base_model, events=self._services.events, @@ -408,7 +408,7 @@ class InvocationContext: config: ConfigInterface, util: UtilInterface, boards: BoardsInterface, - context_data: InvocationContextData, + data: InvocationContextData, services: InvocationServices, ) -> None: self.images = images @@ -427,7 +427,7 @@ class InvocationContext: """Provides utility methods.""" self.boards = boards """Provides methods to interact with boards.""" - self._data = context_data + self._data = data """Provides data about the current queue item and invocation. This is an internal API and may change without warning.""" self._services = services """Provides access to the full application services. This is an internal API and may change without warning.""" @@ -435,7 +435,7 @@ class InvocationContext: def build_invocation_context( services: InvocationServices, - context_data: InvocationContextData, + data: InvocationContextData, cancel_event: threading.Event, ) -> InvocationContext: """ @@ -445,14 +445,14 @@ def build_invocation_context( :param invocation_context_data: The invocation context data. """ - logger = LoggerInterface(services=services, context_data=context_data) - images = ImagesInterface(services=services, context_data=context_data) - tensors = TensorsInterface(services=services, context_data=context_data) - models = ModelsInterface(services=services, context_data=context_data) - config = ConfigInterface(services=services, context_data=context_data) - util = UtilInterface(services=services, context_data=context_data, cancel_event=cancel_event) - conditioning = ConditioningInterface(services=services, context_data=context_data) - boards = BoardsInterface(services=services, context_data=context_data) + logger = LoggerInterface(services=services, data=data) + images = ImagesInterface(services=services, data=data) + tensors = TensorsInterface(services=services, data=data) + models = ModelsInterface(services=services, data=data) + config = ConfigInterface(services=services, data=data) + util = UtilInterface(services=services, data=data, cancel_event=cancel_event) + conditioning = ConditioningInterface(services=services, data=data) + boards = BoardsInterface(services=services, data=data) ctx = InvocationContext( images=images, @@ -460,7 +460,7 @@ def build_invocation_context( config=config, tensors=tensors, models=models, - context_data=context_data, + data=data, util=util, conditioning=conditioning, services=services, diff --git a/tests/test_graph_execution_state.py b/tests/test_graph_execution_state.py index f839a4a878..9cff502acf 100644 --- a/tests/test_graph_execution_state.py +++ b/tests/test_graph_execution_state.py @@ -86,7 +86,7 @@ def invoke_next(g: GraphExecutionState, services: InvocationServices) -> tuple[B InvocationContext( conditioning=None, config=None, - context_data=None, + data=None, images=None, tensors=None, logger=None,