InvokeAI/tests/aa_nodes/test_session_queue.py
psychedelicious 5fc745653a tidy(nodes): remove GraphInvocation
`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons:

1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used.

The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side.

2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic.

3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
2024-02-20 09:48:14 +11:00

257 lines
9.6 KiB
Python

import pytest
from pydantic import TypeAdapter, ValidationError
from invokeai.app.services.session_queue.session_queue_common import (
Batch,
BatchDataCollection,
BatchDatum,
NodeFieldValue,
calc_session_count,
create_session_nfv_tuples,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from tests.aa_nodes.test_nodes import PromptTestInvocation
@pytest.fixture
def batch_data_collection() -> BatchDataCollection:
    """Provide batch data made of one two-datum group and one single-datum group.

    Every datum targets the ``prompt`` field of a different node and carries two items.
    """
    # zipped
    zipped_group = [
        BatchDatum(node_path="1", field_name="prompt", items=["Banana sushi", "Grape sushi"]),
        BatchDatum(node_path="2", field_name="prompt", items=["Strawberry sushi", "Blueberry sushi"]),
    ]
    single_group = [
        BatchDatum(node_path="3", field_name="prompt", items=["Orange sushi", "Apple sushi"]),
    ]
    return [zipped_group, single_group]
@pytest.fixture
def batch_graph() -> Graph:
    """Provide a graph holding four prompt nodes with ids "1" through "4"."""
    graph = Graph()
    default_prompts = {"1": "Chevy", "2": "Toyota", "3": "Subaru", "4": "Nissan"}
    for node_id, prompt in default_prompts.items():
        graph.add_node(PromptTestInvocation(id=node_id, prompt=prompt))
    return graph
# NOTE: disabled because it depends on `GraphInvocation`, which has been removed.
# def test_populate_graph_with_subgraph():
# g1 = Graph()
# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
# n1 = PromptTestInvocation(id="1", prompt="Banana snake")
# subgraph = Graph()
# subgraph.add_node(n1)
# g1.add_node(GraphInvocation(id="3", graph=subgraph))
# nfvs = [
# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
# ]
# g2 = populate_graph(g1, nfvs)
# # do not mutate g1
# assert g1 is not g2
# assert g2.get_node("1").prompt == "Strawberry sushi"
# assert g2.get_node("2").prompt == "Strawberry sunday"
# assert g2.get_node("3.1").prompt == "Strawberry snake"
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
    """Check the number of sessions and the prompts populated into each one across two runs."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    session_tuples = list(create_session_nfv_tuples(batch=batch, maximum=1000))
    # 2 list[BatchDatum] * length 2 * 2 runs = 8
    assert len(session_tuples) == 8

    # Expected prompts for nodes "1", "2" and "3" within a single run, in session order.
    # Nodes "1" and "2" advance together, while node "3" alternates on every session.
    expected_single_run = [
        ("Banana sushi", "Strawberry sushi", "Orange sushi"),
        ("Banana sushi", "Strawberry sushi", "Apple sushi"),
        ("Grape sushi", "Blueberry sushi", "Orange sushi"),
        ("Grape sushi", "Blueberry sushi", "Apple sushi"),
    ]
    # repeat for second run
    expected_prompts = expected_single_run * 2

    for index, (prompt_1, prompt_2, prompt_3) in enumerate(expected_prompts):
        graph = session_tuples[index][0].graph
        assert graph.get_node("1").prompt == prompt_1
        assert graph.get_node("2").prompt == prompt_2
        assert graph.get_node("3").prompt == prompt_3
        # Node "4" has no batch datum, so it keeps the prompt it was created with.
        assert graph.get_node("4").prompt == "Nissan"
def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
    """Check the session count when the batch does not specify ``runs``."""
    b = Batch(graph=batch_graph, data=batch_data_collection)
    t = list(create_session_nfv_tuples(batch=b, maximum=1000))
    # 2 list[BatchDatum] * length 2 * 1 run = 4
    assert len(t) == 4
def test_create_sessions_from_batch_without_batch(batch_graph):
    """Check the session count for a batch that has runs but no batch data."""
    batch = Batch(graph=batch_graph, runs=2)
    session_tuples = list(create_session_nfv_tuples(batch=batch, maximum=1000))
    # 2 runs
    session_count = len(session_tuples)
    assert session_count == 2
def test_create_sessions_from_batch_without_batch_or_runs(batch_graph):
    """Check the session count for a batch that has neither batch data nor ``runs``."""
    batch = Batch(graph=batch_graph)
    session_tuples = list(create_session_nfv_tuples(batch=batch, maximum=1000))
    # 1 run
    session_count = len(session_tuples)
    assert session_count == 1
def test_create_sessions_from_batch_with_runs_and_max(batch_data_collection, batch_graph):
    """Check that ``maximum`` caps the number of sessions produced."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    session_tuples = list(create_session_nfv_tuples(batch=batch, maximum=5))
    # 2 list[BatchDatum] * length 2 * 2 runs = 8, but max is 5
    session_count = len(session_tuples)
    assert session_count == 5
def test_calc_session_count(batch_data_collection, batch_graph):
    """Check the session count calculated for the two-run batch."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    session_count = calc_session_count(batch=batch)
    # 2 list[BatchDatum] * length 2 * 2 runs = 8
    assert session_count == 8
def test_prepare_values_to_insert(batch_data_collection, batch_graph):
    """Check the session, field values, batch id and priority of the prepared values."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    rows = prepare_values_to_insert(queue_id="default", batch=batch, priority=0, max_new_queue_items=1000)
    assert len(rows) == 8

    session_adapter = TypeAdapter(GraphExecutionState)
    # the session is stored as JSON, so it must parse back into a GraphExecutionState
    first_session = session_adapter.validate_json(rows[0].session)

    # the parsed graph should carry the batch values; node "4" keeps its own prompt
    expected_prompts = {
        "1": "Banana sushi",
        "2": "Strawberry sushi",
        "3": "Orange sushi",
        "4": "Nissan",
    }
    for node_id, prompt in expected_prompts.items():
        assert first_session.graph.get_node(node_id).prompt == prompt

    # each row's session_id should match the id inside its serialized session
    for row in rows:
        assert row.session_id == session_adapter.validate_json(row.session).id

    # session ids should be unique
    session_ids = [row.session_id for row in rows]
    assert len(set(session_ids)) == len(session_ids)

    field_values_adapter = TypeAdapter(list[NodeFieldValue])
    # field values are stored as a JSON string holding 3 node field values
    first_field_values = rows[0].field_values
    assert isinstance(first_field_values, str)
    assert len(field_values_adapter.validate_json(first_field_values)) == 3

    # every row should carry the batch id and the requested priority
    for row in rows:
        assert row.batch_id == batch.batch_id
    for row in rows:
        assert row.priority == 0
def test_prepare_values_to_insert_with_priority(batch_data_collection, batch_graph):
    """Check that the requested priority is applied to every prepared value."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    rows = prepare_values_to_insert(queue_id="default", batch=batch, priority=1, max_new_queue_items=1000)
    for row in rows:
        assert row.priority == 1
def test_prepare_values_to_insert_with_max(batch_data_collection, batch_graph):
    """Check that ``max_new_queue_items`` caps the number of prepared values."""
    batch = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
    rows = prepare_values_to_insert(queue_id="default", batch=batch, priority=1, max_new_queue_items=5)
    row_count = len(rows)
    assert row_count == 5
def test_cannot_create_bad_batch_items_length(batch_graph):
    """A group whose data have differing item counts must raise a ValidationError."""
    expected_error = "Zipped batch items must all have the same length"
    with pytest.raises(ValidationError, match=expected_error):
        # Everything is built inside the context so any validation error is captured here.
        mismatched_group = [
            BatchDatum(node_path="1", field_name="prompt", items=["Banana sushi"]),  # 1 item
            BatchDatum(node_path="2", field_name="prompt", items=["Toyota", "Nissan"]),  # 2 items
        ]
        Batch(graph=batch_graph, data=[mismatched_group])
def test_cannot_create_bad_batch_items_type(batch_graph):
    """A datum mixing a string item with an integer item must raise a ValidationError."""
    expected_error = "All items in a batch must have the same type"
    with pytest.raises(ValidationError, match=expected_error):
        # Everything is built inside the context so any validation error is captured here.
        mixed_type_group = [
            BatchDatum(node_path="1", field_name="prompt", items=["Banana sushi", 123]),
        ]
        Batch(graph=batch_graph, data=[mixed_type_group])
def test_cannot_create_bad_batch_unique_ids(batch_graph):
    """Two groups targeting the same node path and field name must raise a ValidationError."""
    expected_error = "Each batch data must have unique node_id and field_name"
    with pytest.raises(ValidationError, match=expected_error):
        # Both groups target node "1" / "prompt"; built inside the context on purpose.
        first_group = [
            BatchDatum(node_path="1", field_name="prompt", items=["Banana sushi"]),
        ]
        duplicate_group = [
            BatchDatum(node_path="1", field_name="prompt", items=["Banana sushi"]),
        ]
        Batch(graph=batch_graph, data=[first_group, duplicate_group])
def test_cannot_create_bad_batch_nodes_exist(batch_graph):
    """A datum whose node path is not in the graph must raise a ValidationError."""
    expected_error = r"Node .* not found in graph"
    with pytest.raises(ValidationError, match=expected_error):
        # "batman" is not one of the node ids ("1" to "4") in the fixture graph.
        unknown_node_group = [
            BatchDatum(node_path="batman", field_name="prompt", items=["Banana sushi"]),
        ]
        Batch(graph=batch_graph, data=[unknown_node_group])
def test_cannot_create_bad_batch_fields_exist(batch_graph):
    """A datum naming a field the target node does not have must raise a ValidationError."""
    expected_error = r"Field .* not found in node"
    with pytest.raises(ValidationError, match=expected_error):
        # Node "1" exists in the fixture graph, but "batman" is not expected to be one of its fields.
        unknown_field_group = [
            BatchDatum(node_path="1", field_name="batman", items=["Banana sushi"]),
        ]
        Batch(graph=batch_graph, data=[unknown_field_group])