mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(batch): use node_path instead of node_id to create batched sessions
This commit is contained in:
parent
26f9ac9f21
commit
e8a4a654ac
@ -3,7 +3,6 @@ from itertools import product
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import networkx as nx
|
||||
from fastapi_events.handlers.local import local_handler
|
||||
from fastapi_events.typing import Event
|
||||
from pydantic import BaseModel, Field
|
||||
@ -114,20 +113,16 @@ class BatchManager(BatchManagerBase):
|
||||
) -> GraphExecutionState:
|
||||
graph = batch_process.graph.copy(deep=True)
|
||||
batch = batch_process.batch
|
||||
g = graph.nx_graph_flat()
|
||||
sorted_nodes = nx.topological_sort(g)
|
||||
for npath in sorted_nodes:
|
||||
node = graph.get_node(npath)
|
||||
for index, bdl in enumerate(batch.data):
|
||||
relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
|
||||
if not relevant_bd:
|
||||
for index, bdl in enumerate(batch.data):
|
||||
for bd in bdl:
|
||||
node = graph.get_node(bd.node_path)
|
||||
if node is None:
|
||||
continue
|
||||
for bd in relevant_bd:
|
||||
batch_index = batch_indices[index]
|
||||
datum = bd.items[batch_index]
|
||||
key = bd.field_name
|
||||
node.__dict__[key] = datum
|
||||
graph.update_node(npath, node)
|
||||
batch_index = batch_indices[index]
|
||||
datum = bd.items[batch_index]
|
||||
key = bd.field_name
|
||||
node.__dict__[key] = datum
|
||||
graph.update_node(bd.node_path, node)
|
||||
|
||||
return GraphExecutionState(graph=graph)
|
||||
|
||||
|
@ -17,7 +17,7 @@ class BatchData(BaseModel):
|
||||
A batch data collection.
|
||||
"""
|
||||
|
||||
node_id: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
node_path: str = Field(description="The node into which this batch data collection will be substituted.")
|
||||
field_name: str = Field(description="The field into which this batch data collection will be substituted.")
|
||||
items: list[BatchDataType] = Field(
|
||||
default_factory=list, description="The list of items to substitute into the node/field."
|
||||
@ -64,7 +64,7 @@ class Batch(BaseModel):
|
||||
count: int = 0
|
||||
for batch_data in v:
|
||||
for datum in batch_data:
|
||||
paths.add((datum.node_id, datum.field_name))
|
||||
paths.add((datum.node_path, datum.field_name))
|
||||
count += 1
|
||||
if len(paths) != count:
|
||||
raise ValueError("Each batch data must have unique node_id and field_name")
|
||||
|
@ -41,7 +41,7 @@ def simple_batch():
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
@ -54,7 +54,7 @@ def simple_batch():
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_id="2",
|
||||
node_path="2",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
@ -196,11 +196,11 @@ def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
|
||||
def test_cannot_create_bad_batches():
|
||||
batch = None
|
||||
try:
|
||||
batch = Batch( # This batch has a duplicate node_id|fieldname combo
|
||||
batch = Batch( # This batch has a duplicate node_path|fieldname combo
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
@ -209,7 +209,7 @@ def test_cannot_create_bad_batches():
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
@ -225,14 +225,14 @@ def test_cannot_create_bad_batches():
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
],
|
||||
),
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Tomato sushi",
|
||||
@ -242,7 +242,7 @@ def test_cannot_create_bad_batches():
|
||||
],
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=[
|
||||
"Ume sushi",
|
||||
@ -258,7 +258,7 @@ def test_cannot_create_bad_batches():
|
||||
data=[
|
||||
[
|
||||
BatchData(
|
||||
node_id="1",
|
||||
node_path="1",
|
||||
field_name="prompt",
|
||||
items=["Tomato sushi", 5],
|
||||
),
|
||||
|
Loading…
Reference in New Issue
Block a user