feat(nodes): add whole queue_item to InvocationContextData

No reason to not have the whole thing in there.
This commit is contained in:
psychedelicious 2024-02-18 11:51:50 +11:00 committed by Brandon Rising
parent fafaa09f5e
commit d35f986351
3 changed files with 18 additions and 30 deletions

View File

@ -139,7 +139,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Loop over invocations until the session is complete or canceled # Loop over invocations until the session is complete or canceled
while invocation is not None and not cancel_event.is_set(): while invocation is not None and not cancel_event.is_set():
# get the source node id to provide to clients (the prepared node id is not as useful) # get the source node id to provide to clients (the prepared node id is not as useful)
source_node_id = self._queue_item.session.prepared_source_mapping[invocation.id] source_invocation_id = self._queue_item.session.prepared_source_mapping[invocation.id]
# Send starting event # Send starting event
self._invoker.services.events.emit_invocation_started( self._invoker.services.events.emit_invocation_started(
@ -148,7 +148,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
queue_id=self._queue_item.queue_id, queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session_id, graph_execution_state_id=self._queue_item.session_id,
node=invocation.model_dump(), node=invocation.model_dump(),
source_node_id=source_node_id, source_node_id=source_invocation_id,
) )
# Innermost processor try block; any unhandled exception is an invocation error & will fail the graph # Innermost processor try block; any unhandled exception is an invocation error & will fail the graph
@ -159,12 +159,8 @@ class DefaultSessionProcessor(SessionProcessorBase):
# Build invocation context (the node-facing API) # Build invocation context (the node-facing API)
context_data = InvocationContextData( context_data = InvocationContextData(
invocation=invocation, invocation=invocation,
source_node_id=source_node_id, source_invocation_id=source_invocation_id,
session_id=self._queue_item.session.id, queue_item=self._queue_item,
workflow=self._queue_item.workflow,
queue_id=self._queue_item.queue_id,
queue_item_id=self._queue_item.item_id,
batch_id=self._queue_item.batch_id,
) )
context = build_invocation_context( context = build_invocation_context(
context_data=context_data, context_data=context_data,
@ -187,7 +183,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
queue_id=self._queue_item.queue_id, queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id, graph_execution_state_id=self._queue_item.session.id,
node=invocation.model_dump(), node=invocation.model_dump(),
source_node_id=source_node_id, source_node_id=source_invocation_id,
result=outputs.model_dump(), result=outputs.model_dump(),
) )
@ -224,7 +220,7 @@ class DefaultSessionProcessor(SessionProcessorBase):
queue_id=self._queue_item.queue_id, queue_id=self._queue_item.queue_id,
graph_execution_state_id=self._queue_item.session.id, graph_execution_state_id=self._queue_item.session.id,
node=invocation.model_dump(), node=invocation.model_dump(),
source_node_id=source_node_id, source_node_id=source_invocation_id,
error_type=e.__class__.__name__, error_type=e.__class__.__name__,
error=error, error=error,
) )

View File

@ -13,7 +13,6 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.model_manager.load.load_base import LoadedModel
@ -23,6 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.invocations.baseinvocation import BaseInvocation from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
""" """
The InvocationContext provides access to various services and data about the current invocation. The InvocationContext provides access to various services and data about the current invocation.
@ -49,20 +49,12 @@ Note: The docstrings are in weird places, but that's where they must be to get I
@dataclass @dataclass
class InvocationContextData: class InvocationContextData:
queue_item: "SessionQueueItem"
"""The queue item that is being executed."""
invocation: "BaseInvocation" invocation: "BaseInvocation"
"""The invocation that is being executed.""" """The invocation that is being executed."""
session_id: str source_invocation_id: str
"""The session that is being executed.""" """The ID of the invocation from which the currently executing invocation was prepared."""
queue_id: str
"""The queue in which the session is being executed."""
source_node_id: str
"""The ID of the node from which the currently executing invocation was prepared."""
queue_item_id: int
"""The ID of the queue item that is being executed."""
batch_id: str
"""The ID of the batch that is being executed."""
workflow: Optional[WorkflowWithoutID] = None
"""The workflow associated with this queue item, if any."""
class InvocationContextInterface: class InvocationContextInterface:
@ -191,8 +183,8 @@ class ImagesInterface(InvocationContextInterface):
board_id=board_id_, board_id=board_id_,
metadata=metadata_, metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
workflow=self._context_data.workflow, workflow=self._context_data.queue_item.workflow,
session_id=self._context_data.session_id, session_id=self._context_data.queue_item.session_id,
node_id=self._context_data.invocation.id, node_id=self._context_data.invocation.id,
) )

View File

@ -114,12 +114,12 @@ def stable_diffusion_step_callback(
dataURL = image_to_dataURL(image, image_format="JPEG") dataURL = image_to_dataURL(image, image_format="JPEG")
events.emit_generator_progress( events.emit_generator_progress(
queue_id=context_data.queue_id, queue_id=context_data.queue_item.queue_id,
queue_item_id=context_data.queue_item_id, queue_item_id=context_data.queue_item.item_id,
queue_batch_id=context_data.batch_id, queue_batch_id=context_data.queue_item.batch_id,
graph_execution_state_id=context_data.session_id, graph_execution_state_id=context_data.queue_item.session_id,
node_id=context_data.invocation.id, node_id=context_data.invocation.id,
source_node_id=context_data.source_node_id, source_node_id=context_data.source_invocation_id,
progress_image=ProgressImage(width=width, height=height, dataURL=dataURL), progress_image=ProgressImage(width=width, height=height, dataURL=dataURL),
step=intermediate_state.step, step=intermediate_state.step,
order=intermediate_state.order, order=intermediate_state.order,