feat(batch): use node_path instead of node_id to create batched sessions

This commit is contained in:
psychedelicious 2023-09-05 17:30:27 +10:00
parent 26f9ac9f21
commit e8a4a654ac
3 changed files with 20 additions and 25 deletions

View File

@@ -3,7 +3,6 @@ from itertools import product
from typing import Optional from typing import Optional
from uuid import uuid4 from uuid import uuid4
import networkx as nx
from fastapi_events.handlers.local import local_handler from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event from fastapi_events.typing import Event
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -114,20 +113,16 @@ class BatchManager(BatchManagerBase):
) -> GraphExecutionState: ) -> GraphExecutionState:
graph = batch_process.graph.copy(deep=True) graph = batch_process.graph.copy(deep=True)
batch = batch_process.batch batch = batch_process.batch
g = graph.nx_graph_flat() for index, bdl in enumerate(batch.data):
sorted_nodes = nx.topological_sort(g) for bd in bdl:
for npath in sorted_nodes: node = graph.get_node(bd.node_path)
node = graph.get_node(npath) if node is None:
for index, bdl in enumerate(batch.data):
relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
if not relevant_bd:
continue continue
for bd in relevant_bd: batch_index = batch_indices[index]
batch_index = batch_indices[index] datum = bd.items[batch_index]
datum = bd.items[batch_index] key = bd.field_name
key = bd.field_name node.__dict__[key] = datum
node.__dict__[key] = datum graph.update_node(bd.node_path, node)
graph.update_node(npath, node)
return GraphExecutionState(graph=graph) return GraphExecutionState(graph=graph)

View File

@@ -17,7 +17,7 @@ class BatchData(BaseModel):
A batch data collection. A batch data collection.
""" """
node_id: str = Field(description="The node into which this batch data collection will be substituted.") node_path: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.") field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field( items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field." default_factory=list, description="The list of items to substitute into the node/field."
@@ -64,7 +64,7 @@ class Batch(BaseModel):
count: int = 0 count: int = 0
for batch_data in v: for batch_data in v:
for datum in batch_data: for datum in batch_data:
paths.add((datum.node_id, datum.field_name)) paths.add((datum.node_path, datum.field_name))
count += 1 count += 1
if len(paths) != count: if len(paths) != count:
raise ValueError("Each batch data must have unique node_id and field_name") raise ValueError("Each batch data must have unique node_id and field_name")

View File

@@ -41,7 +41,7 @@ def simple_batch():
data=[ data=[
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Tomato sushi", "Tomato sushi",
@@ -54,7 +54,7 @@ def simple_batch():
], ],
[ [
BatchData( BatchData(
node_id="2", node_path="2",
field_name="prompt", field_name="prompt",
items=[ items=[
"Ume sushi", "Ume sushi",
@@ -196,11 +196,11 @@ def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
def test_cannot_create_bad_batches(): def test_cannot_create_bad_batches():
batch = None batch = None
try: try:
batch = Batch( # This batch has a duplicate node_id|fieldname combo batch = Batch( # This batch has a duplicate node_path|fieldname combo
data=[ data=[
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Tomato sushi", "Tomato sushi",
@@ -209,7 +209,7 @@ def test_cannot_create_bad_batches():
], ],
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Ume sushi", "Ume sushi",
@@ -225,14 +225,14 @@ def test_cannot_create_bad_batches():
data=[ data=[
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Tomato sushi", "Tomato sushi",
], ],
), ),
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Tomato sushi", "Tomato sushi",
@@ -242,7 +242,7 @@ def test_cannot_create_bad_batches():
], ],
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=[ items=[
"Ume sushi", "Ume sushi",
@@ -258,7 +258,7 @@ def test_cannot_create_bad_batches():
data=[ data=[
[ [
BatchData( BatchData(
node_id="1", node_path="1",
field_name="prompt", field_name="prompt",
items=["Tomato sushi", 5], items=["Tomato sushi", 5],
), ),