WIP running graphs as batches

Brandon Rising 2023-07-24 17:41:54 -04:00
parent 4f9c728db0
commit 4bad96d9d6
2 changed files with 33 additions and 3 deletions


@@ -241,6 +241,13 @@ InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Batch(BaseModel):
    batch_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch")
    data: list[dict[str, Any]] = Field(description="List of field/value mappings to apply to the batched node, one per run")
    node_id: str = Field(description="ID of the node to batch")


class Graph(BaseModel):
    id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())

    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
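As a sketch of how the new model might be exercised, assuming data carries one dict of field overrides per run (the shape the invoker below consumes) and a hypothetical "noise" node id:

import uuid
from typing import Any, Optional
from pydantic import BaseModel, Field

# Trimmed stand-in for the Batch model added above
class Batch(BaseModel):
    batch_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()), description="Identifier for this batch")
    data: list[dict[str, Any]] = Field(description="Per-run field overrides for the batched node")
    node_id: str = Field(description="ID of the node to batch")

# Three runs of a hypothetical "noise" node, each with a different seed
batch = Batch(node_id="noise", data=[{"seed": 1}, {"seed": 2}, {"seed": 3}])
print(batch.batch_id, len(batch.data))  # a unique uuid string, 3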
@@ -251,13 +258,16 @@ class Graph(BaseModel):
description="The connections between nodes and their fields in this graph",
default_factory=list,
)
batches: list[Batch] = Field(
description="List of batch configs to apply to this session",
default_factory=list,
)
def add_node(self, node: BaseInvocation) -> None:
"""Adds a node to a graph
:raises NodeAlreadyInGraphError: the node is already present in the graph.
"""
if node.id in self.nodes:
raise NodeAlreadyInGraphError()
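The batches list then rides along on the session graph; a minimal sketch reusing the Batch stand-in above (the Graph stand-in keeps only the fields relevant to batching):

import uuid
from pydantic import BaseModel, Field

class Graph(BaseModel):
    # Trimmed stand-in: nodes and edges omitted
    id: str = Field(description="The id of this graph", default_factory=lambda: str(uuid.uuid4()))
    batches: list[Batch] = Field(description="List of batch configs to apply to this session", default_factory=list)

graph = Graph()
graph.batches.append(Batch(node_id="noise", data=[{"seed": 1}, {"seed": 2}]))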
@@ -793,6 +803,8 @@ class GraphExecutionState(BaseModel):
    # TODO: Store a reference to the graph instead of the actual graph?
    graph: Graph = Field(description="The graph being executed")
    batch_index: list[int] = Field(description="Index of the next datum to apply for each batch config", default_factory=list)

    # The graph of materialized nodes
    execution_graph: Graph = Field(
        description="The expanded graph of activated and executed nodes",
@@ -865,6 +877,13 @@ class GraphExecutionState(BaseModel):
        if next_node is not None:
            self._prepare_inputs(next_node)

        # If the graph is exhausted but batch runs remain, reset and run it again
        if next_node is None and sum(self.batch_index) != 0:
            for index in range(len(self.batch_index)):
                if self.batch_index[index] > 0:
                    self.executed.clear()
                    self.batch_index[index] -= 1
                    return self.next()

        # If next is still None, there's no next node; return None
        return next_node
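The countdown-and-reset behaviour can be read in isolation; a sketch under the assumption that each counter holds the remaining runs for one batch config and any positive counter triggers another full pass:

def consume_batch_index(batch_index: list[int], executed: set[str]) -> bool:
    """Decrement the first positive counter, clearing executed nodes.
    Returns True if another full pass over the graph should run."""
    for i in range(len(batch_index)):
        if batch_index[i] > 0:
            executed.clear()
            batch_index[i] -= 1
            return True
    return False

# Two batch configs with 2 and 1 runs left: three more passes in total
counters, done = [2, 1], {"some-node-id"}
passes = 0
while consume_batch_index(counters, done):
    passes += 1
print(passes, counters)  # 3 [0, 0]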
@@ -954,7 +973,7 @@ class GraphExecutionState(BaseModel):
            new_node = copy.deepcopy(node)

            # Create the node id (a random uuid plus the source node id)
            new_node.id = str(uuid.uuid4())
            new_node.id = f"{uuid.uuid4()}-{node.id}"

            # Set the iteration index for iteration invocations
            if isinstance(new_node, IterateInvocation):
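The id change makes the source node's id recoverable from the prepared node's id, which the invoker's lookup below depends on; a small sketch with hypothetical ids:

import uuid

source_id = "noise"
prepared_id = f"{uuid.uuid4()}-{source_id}"  # e.g. "1b9d6bcd-...-noise"

assert source_id in prepared_id  # what the invoker's substring match relies on
# Note: a plain substring test can collide when one node id is a suffix of
# another; since uuid4 strings are always 36 characters, a stricter check is:
assert prepared_id[37:] == source_id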


@@ -21,11 +21,17 @@ class Invoker:
    ) -> Optional[str]:
        """Determines the next node to invoke and enqueues it, preparing if needed.
        Returns the id of the queued node, or `None` if there are no nodes left to enqueue."""

        # Get the next invocation
        invocation = graph_execution_state.next()
        if not invocation:
            return None
        # Find the batch config (if any) whose node id is embedded in this prepared node's id
        (index, batch) = next(
            ((i, b) for i, b in enumerate(graph_execution_state.graph.batches) if b.node_id in invocation.id),
            (None, None),
        )
        if batch is not None:
            # assert invocation.type == batch.node_type, f"Type mismatch between node and batch config on {invocation.id}"
            batch_index = graph_execution_state.batch_index[index]
            datum = batch.data[batch_index]
            for param, value in datum.items():
                # Pydantic models don't support item assignment, so set fields by name
                setattr(invocation, param, value)

        # Save the execution state
        self.services.graph_execution_manager.set(graph_execution_state)
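Since pydantic models don't support item assignment, the override step amounts to setattr calls on the prepared invocation; a minimal sketch with a hypothetical invocation model:

from pydantic import BaseModel

class NoiseInvocation(BaseModel):
    # Hypothetical stand-in for a prepared invocation
    id: str
    seed: int = 0

invocation = NoiseInvocation(id="7f3c-noise")
datum = {"seed": 42}  # one entry from batch.data

for param, value in datum.items():
    setattr(invocation, param, value)
print(invocation.seed)  # 42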
@@ -45,6 +51,11 @@ class Invoker:
    def create_execution_state(self, graph: Optional[Graph] = None) -> GraphExecutionState:
        """Creates a new execution state for the given graph"""
        new_state = GraphExecutionState(graph=Graph() if graph is None else graph)
        if graph is not None and graph.batches:
            # Each counter starts at the last data index and is consumed down to zero
            new_state.batch_index = [len(batch.data) - 1 for batch in graph.batches]
        self.services.graph_execution_manager.set(new_state)
        return new_state
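Putting the pieces together, a sketch of the initial counters a session would start with, reusing the stand-ins from the earlier sketches:

graph = Graph()
graph.batches = [
    Batch(node_id="noise", data=[{"seed": 1}, {"seed": 2}, {"seed": 3}]),
    Batch(node_id="prompt", data=[{"text": "a"}, {"text": "b"}]),
]
batch_index = [len(batch.data) - 1 for batch in graph.batches]
print(batch_index)  # [2, 1]: each counter indexes batch.data and counts down to 0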