WIP running graphs as batches

This commit is contained in:
Brandon Rising 2023-07-24 17:41:54 -04:00
parent 4f9c728db0
commit 4bad96d9d6
2 changed files with 33 additions and 3 deletions

View File

@@ -241,6 +241,13 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]  # type: ignore
class Batch(BaseModel):
    """Configuration for batching a single node across multiple runs of a graph.

    Each entry in ``data`` supplies the field values to inject into the node
    identified by ``node_id`` for one run of the batch.
    """

    # BUG FIX: the original used `default_factory=uuid.uuid4().__str__`, which
    # calls uuid.uuid4() ONCE at class-definition time and binds that single
    # UUID's __str__ — so every Batch instance shared the same id. The lambda
    # defers the call so each instance gets a fresh UUID.
    batch_id: Optional[str] = Field(
        default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch"
    )
    # One entry per batch run; consumed by index as the batch counter walks down.
    data: list[InvocationsUnion] = Field(description="List of field-value payloads, one per batch run")
    node_id: str = Field(description="ID of the node to batch")
class Graph(BaseModel):
    id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
@@ -251,13 +258,16 @@ class Graph(BaseModel):
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )
batches: list[Batch] = Field(
description="List of batch configs to apply to this session",
default_factory=list,
)
    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to a graph

        :raises NodeAlreadyInGraphError: the node is already present in the graph.
        """
        if node.id in self.nodes:
            raise NodeAlreadyInGraphError()
@@ -793,6 +803,8 @@ class GraphExecutionState(BaseModel):
    # TODO: Store a reference to the graph instead of the actual graph?
    graph: Graph = Field(description="The graph being executed")
batch_index: list[int] = Field(description="Tracker for which batch is currently being processed", default_factory=list)
    # The graph of materialized nodes
    execution_graph: Graph = Field(
        description="The expanded graph of activated and executed nodes",
@@ -865,6 +877,13 @@ class GraphExecutionState(BaseModel):
        if next_node is not None:
            self._prepare_inputs(next_node)
if sum(self.batch_index) != 0:
for index in self.batch_index:
if self.batch_index[index] > 0:
self.executed.clear()
self.batch_index[index] -= 1
return next(self)
        # If next is still none, there's no next node, return None
        return next_node

@@ -954,7 +973,7 @@ class GraphExecutionState(BaseModel):
        new_node = copy.deepcopy(node)
        # Create the node id (use a random uuid)
new_node.id = str(uuid.uuid4()) new_node.id = str(f"{uuid.uuid4()}-{node.id}")
        # Set the iteration index for iteration invocations
        if isinstance(new_node, IterateInvocation):

View File

@@ -21,11 +21,17 @@ class Invoker:
    ) -> Optional[str]:
        """Determines the next node to invoke and enqueues it, preparing if needed.

        Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""
        # Get the next invocation
        invocation = graph_execution_state.next()
        if not invocation:
            return None
(index, batch) = next(((i,b) for i,b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id), (None, None))
if batch:
# assert(isinstance(invocation.type, batch.node_type), f"Type mismatch between nodes and batch config on {invocation.id}")
batch_index = graph_execution_state.batch_index[index]
datum = batch.data[batch_index]
for param in datum.keys():
invocation[param] = datum[param]
        # Save the execution state
        self.services.graph_execution_manager.set(graph_execution_state)
@@ -45,6 +51,11 @@ class Invoker:
def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState: def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
"""Creates a new execution state for the given graph""" """Creates a new execution state for the given graph"""
new_state = GraphExecutionState(graph=Graph() if graph is None else graph) new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
if graph.batches:
batch_index = list()
for batch in graph.batches:
batch_index.append(len(batch.data)-1)
new_state.batch_index = batch_index
        self.services.graph_execution_manager.set(new_state)
        return new_state