From 4bad96d9d67e9f9697fb9e82378716875bd9aee2 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Mon, 24 Jul 2023 17:41:54 -0400 Subject: [PATCH] WIP running graphs as batches --- invokeai/app/services/graph.py | 23 +++++++++++++++++++++-- invokeai/app/services/invoker.py | 13 ++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 24096da29b..fb3855f8ed 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -241,6 +241,13 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore + +class Batch(BaseModel): + batch_id: Optional[str] = Field(default_factory=uuid.uuid4().__str__, description="Identifier for this batch") + data: list[InvocationsUnion] = Field(description="Mapping of ") + node_id: str = Field(description="ID of the node to batch") + + class Graph(BaseModel): id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__()) # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me @@ -251,13 +258,16 @@ class Graph(BaseModel): description="The connections between nodes and their fields in this graph", default_factory=list, ) + batches: list[Batch] = Field( + description="List of batch configs to apply to this session", + default_factory=list, + ) def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph :raises NodeAlreadyInGraphError: the node is already present in the graph. """ - if node.id in self.nodes: raise NodeAlreadyInGraphError() @@ -793,6 +803,8 @@ class GraphExecutionState(BaseModel): # TODO: Store a reference to the graph instead of the actual graph? graph: Graph = Field(description="The graph being executed") + batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list) + # The graph of materialized nodes execution_graph: Graph = Field( description="The expanded graph of activated and executed nodes", @@ -865,6 +877,13 @@ class GraphExecutionState(BaseModel): if next_node is not None: self._prepare_inputs(next_node) + if sum(self.batch_index) != 0: + for index in self.batch_index: + if self.batch_index[index] > 0: + self.executed.clear() + self.batch_index[index] -= 1 + return next(self) + # If next is still none, there's no next node, return None return next_node @@ -954,7 +973,7 @@ class GraphExecutionState(BaseModel): new_node = copy.deepcopy(node) # Create the node id (use a random uuid) - new_node.id = str(uuid.uuid4()) + new_node.id = str(f"{uuid.uuid4()}-{node.id}") # Set the iteration index for iteration invocations if isinstance(new_node, IterateInvocation): diff --git a/invokeai/app/services/invoker.py b/invokeai/app/services/invoker.py index 951d3b17c4..3092b15cd3 100644 --- a/invokeai/app/services/invoker.py +++ b/invokeai/app/services/invoker.py @@ -21,11 +21,17 @@ class Invoker: ) -> Optional[str]: """Determines the next node to invoke and enqueues it, preparing if needed. Returns the id of the queued node, or `None` if there are no nodes left to enqueue.""" - # Get the next invocation invocation = graph_execution_state.next() if not invocation: return None + (index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None)) + if batch: + # assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}") + batch_index = graph_execution_state.batch_index[index] + datum = batch.data[batch_index] + for param in datum.keys(): + invocation[param] = datum[param] # Save the execution state self.services.graph_execution_manager.set(graph_execution_state) @@ -45,6 +51,11 @@ class Invoker: def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState: """Creates a new execution state for the given graph""" new_state = GraphExecutionState(graph=Graph() if graph is None else graph) + if graph.batches: + batch_index = list() + for batch in graph.batches: + batch_index.append(len(batch.data)-1) + new_state.batch_index = batch_index self.services.graph_execution_manager.set(new_state) return new_state