Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

feat(batch): use node_path instead of node_id to create batched sessions

commit e8a4a654ac
parent 26f9ac9f21
@@ -3,7 +3,6 @@ from itertools import product
 from typing import Optional
 from uuid import uuid4
 
-import networkx as nx
 from fastapi_events.handlers.local import local_handler
 from fastapi_events.typing import Event
 from pydantic import BaseModel, Field
@@ -114,20 +113,16 @@ class BatchManager(BatchManagerBase):
     ) -> GraphExecutionState:
         graph = batch_process.graph.copy(deep=True)
         batch = batch_process.batch
-        g = graph.nx_graph_flat()
-        sorted_nodes = nx.topological_sort(g)
-        for npath in sorted_nodes:
-            node = graph.get_node(npath)
-            for index, bdl in enumerate(batch.data):
-                relevant_bd = [bd for bd in bdl if bd.node_id in node.id]
-                if not relevant_bd:
+        for index, bdl in enumerate(batch.data):
+            for bd in bdl:
+                node = graph.get_node(bd.node_path)
+                if node is None:
                     continue
-                for bd in relevant_bd:
-                    batch_index = batch_indices[index]
-                    datum = bd.items[batch_index]
-                    key = bd.field_name
-                    node.__dict__[key] = datum
-            graph.update_node(npath, node)
+                batch_index = batch_indices[index]
+                datum = bd.items[batch_index]
+                key = bd.field_name
+                node.__dict__[key] = datum
+                graph.update_node(bd.node_path, node)
 
         return GraphExecutionState(graph=graph)
 
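For readers skimming the hunk above: the old code walked every node in topological order (hence the networkx import removed in the first hunk) and matched batch data against each node by node_id; the new code iterates the batch data directly and looks each node up by node_path. A minimal sketch of the new loop, assuming only that graph exposes get_node(node_path) and update_node(node_path, node) as used above; the standalone function name and shape are illustrative, not the actual BatchManager method:

def apply_batch_data(graph, batch_data, batch_indices):
    # batch_data: list of lists of BatchData; batch_indices: one chosen index per inner list.
    for index, bdl in enumerate(batch_data):
        for bd in bdl:
            node = graph.get_node(bd.node_path)    # direct lookup by path
            if node is None:                       # unknown path: skip it
                continue
            datum = bd.items[batch_indices[index]] # pick this run's item
            node.__dict__[bd.field_name] = datum   # substitute the field value
            graph.update_node(bd.node_path, node)  # write the modified node back
    return graph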
@@ -17,7 +17,7 @@ class BatchData(BaseModel):
     A batch data collection.
     """
 
-    node_id: str = Field(description="The node into which this batch data collection will be substituted.")
+    node_path: str = Field(description="The node into which this batch data collection will be substituted.")
     field_name: str = Field(description="The field into which this batch data collection will be substituted.")
     items: list[BatchDataType] = Field(
         default_factory=list, description="The list of items to substitute into the node/field."
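A small usage sketch of the renamed field (the import path is omitted because the diff does not show the module name; values mirror the simple_batch fixture further down):

bd = BatchData(
    node_path="1",                 # path of the target node in the graph (formerly node_id)
    field_name="prompt",           # field on that node to substitute into
    items=["Tomato sushi", "Ume sushi"],
)
batch = Batch(data=[[bd]])         # Batch groups BatchData into lists of lists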
@@ -64,7 +64,7 @@ class Batch(BaseModel):
         count: int = 0
         for batch_data in v:
             for datum in batch_data:
-                paths.add((datum.node_id, datum.field_name))
+                paths.add((datum.node_path, datum.field_name))
                 count += 1
         if len(paths) != count:
             raise ValueError("Each batch data must have unique node_id and field_name")
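The rule this validator enforces, restated as a standalone function (a sketch of the same logic, not the Pydantic validator itself):

def check_unique_paths(data) -> None:
    # Every (node_path, field_name) pair may appear at most once across all
    # BatchData entries; note the error message in the source still says node_id.
    paths: set[tuple[str, str]] = set()
    count = 0
    for batch_data in data:
        for datum in batch_data:
            paths.add((datum.node_path, datum.field_name))
            count += 1
    if len(paths) != count:
        raise ValueError("Each batch data must have unique node_id and field_name")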
@@ -41,7 +41,7 @@ def simple_batch():
         data=[
             [
                 BatchData(
-                    node_id="1",
+                    node_path="1",
                     field_name="prompt",
                     items=[
                         "Tomato sushi",
@@ -54,7 +54,7 @@ def simple_batch():
             ],
             [
                 BatchData(
-                    node_id="2",
+                    node_path="2",
                     field_name="prompt",
                     items=[
                         "Ume sushi",
@@ -196,11 +196,11 @@ def test_can_create_batch(mock_invoker: Invoker, simple_graph, simple_batch):
 def test_cannot_create_bad_batches():
     batch = None
     try:
-        batch = Batch(  # This batch has a duplicate node_id|fieldname combo
+        batch = Batch(  # This batch has a duplicate node_path|fieldname combo
             data=[
                 [
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=[
                             "Tomato sushi",
@@ -209,7 +209,7 @@ def test_cannot_create_bad_batches():
                 ],
                 [
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=[
                             "Ume sushi",
@@ -225,14 +225,14 @@ def test_cannot_create_bad_batches():
             data=[
                 [
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=[
                             "Tomato sushi",
                         ],
                     ),
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=[
                             "Tomato sushi",
@@ -242,7 +242,7 @@ def test_cannot_create_bad_batches():
                 ],
                 [
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=[
                             "Ume sushi",
@@ -258,7 +258,7 @@ def test_cannot_create_bad_batches():
             data=[
                 [
                     BatchData(
-                        node_id="1",
+                        node_path="1",
                         field_name="prompt",
                         items=["Tomato sushi", 5],
                     ),
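The duplicate case these test hunks exercise can also be written with pytest.raises, assuming Pydantic surfaces the validator's ValueError as a ValidationError and that Batch and BatchData are imported as in the tests; this is a sketch, not the repository's test style (which, as shown above, uses try/except with batch = None):

import pytest
from pydantic import ValidationError

def test_duplicate_node_path_is_rejected():
    # Two entries target the same (node_path, field_name) pair, which the
    # Batch validator rejects.
    with pytest.raises(ValidationError):
        Batch(
            data=[
                [
                    BatchData(node_path="1", field_name="prompt", items=["Tomato sushi"]),
                    BatchData(node_path="1", field_name="prompt", items=["Ume sushi"]),
                ],
            ]
        )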