feat(batch): use node_path instead of node_id to create batched sessions

This commit is contained in:
psychedelicious 2023-09-05 17:30:27 +10:00
parent 26f9ac9f21
commit e8a4a654ac
3 changed files with 20 additions and 25 deletions

View File

@ -3,7 +3,6 @@ from itertools import product
from typing import Optional
from uuid import uuid4
import networkx as nx
from fastapi_events.handlers.local import local_handler
from fastapi_events.typing import Event
from pydantic import BaseModel, Field
@ -114,20 +113,16 @@ class BatchManager(BatchManagerBase):
) -> GraphExecutionState:
graph = batch_process.graph.copy(deep=True)
batch = batch_process.batch
g = graph.nx_graph_flat()
sorted_nodes = nx.topological_sort(g)
for npath in sorted_nodes:
node = graph.get_node(npath)
for index, bdl in enumerate(batch.data):
relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
if not relevant_bd:
for index, bdl in enumerate(batch.data):
for bd in bdl:
node = graph.get_node(bd.node_path)
if node is None:
continue
for bd in relevant_bd:
batch_index = batch_indices[index]
datum = bd.items[batch_index]
key = bd.field_name
node.__dict__[key] = datum
graph.update_node(npath, node)
batch_index = batch_indices[index]
datum = bd.items[batch_index]
key = bd.field_name
node.__dict__[key] = datum
graph.update_node(bd.node_path, node)
return GraphExecutionState(graph=graph)

View File

@ -17,7 +17,7 @@ class BatchData(BaseModel):
A batch data collection.
"""
node_id: str = Field(description="The node into which this batch data collection will be substituted.")
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
items: list[BatchDataType] = Field(
default_factory=list, description="The list of items to substitute into the node/field."
@ -64,7 +64,7 @@ class Batch(BaseModel):
count: int = 0
for batch_data in v:
for datum in batch_data:
paths.add((datum.node_id, datum.field_name))
paths.add((datum.node_path, datum.field_name))
count += 1
if len(paths) != count:
raise ValueError("Each batch data must have unique node_id and field_name")

View File

@ -41,7 +41,7 @@ def simple_batch():
data=[
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Tomato sushi",
@ -54,7 +54,7 @@ def simple_batch():
],
[
BatchData(
node_id="2",
node_path="2",
field_name="prompt",
items=[
"Ume sushi",
@ -196,11 +196,11 @@ def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
def test_cannot_create_bad_batches():
batch = None
try:
batch = Batch( # This batch has a duplicate node_id|fieldname combo
batch = Batch( # This batch has a duplicate node_path|fieldname combo
data=[
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Tomato sushi",
@ -209,7 +209,7 @@ def test_cannot_create_bad_batches():
],
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Ume sushi",
@ -225,14 +225,14 @@ def test_cannot_create_bad_batches():
data=[
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Tomato sushi",
],
),
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Tomato sushi",
@ -242,7 +242,7 @@ def test_cannot_create_bad_batches():
],
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=[
"Ume sushi",
@ -258,7 +258,7 @@ def test_cannot_create_bad_batches():
data=[
[
BatchData(
node_id="1",
node_path="1",
field_name="prompt",
items=["Tomato sushi", 5],
),