diff --git a/tests/test_node_graph.py b/tests/test_node_graph.py index 94682962ad..87a4948af4 100644 --- a/tests/test_node_graph.py +++ b/tests/test_node_graph.py @@ -421,21 +421,6 @@ def test_graph_invalid_if_edges_reference_missing_nodes(): assert g.is_valid() is False -# def test_graph_invalid_if_subgraph_invalid(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") -# n1.graph.nodes[n1_1.id] = n1_1 -# e1 = create_edge("1", "image", "2", "image") -# n1.graph.edges.append(e1) - -# g.nodes[n1.id] = n1 - -# assert g.is_valid() is False - - def test_graph_invalid_if_has_cycle(): g = Graph() n1 = ESRGANInvocation(id="1") @@ -462,110 +447,6 @@ def test_graph_invalid_with_invalid_connection(): assert g.is_valid() is False -# # TODO: Subgraph operations -# def test_graph_gets_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# result = g.get_node("1.1") - -# assert result is not None -# assert result.id == "1" -# assert result == n1_1 - - -# def test_graph_expands_subgraph(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = AddInvocation(id="1", a=1, b=2) -# n1_2 = SubtractInvocation(id="2", b=3) -# n1.graph.add_node(n1_1) -# n1.graph.add_node(n1_2) -# n1.graph.add_edge(create_edge("1", "value", "2", "a")) - -# g.add_node(n1) - -# n2 = AddInvocation(id="2", b=5) -# g.add_node(n2) -# g.add_edge(create_edge("1.2", "value", "2", "a")) - -# dg = g.nx_graph_flat() -# assert set(dg.nodes) == {"1.1", "1.2", "2"} -# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} - - -# def test_graph_subgraph_t2i(): -# g = Graph() -# n1 = GraphInvocation(id="1") - -# # Get text to image default graph -# lg = create_text_to_image() -# n1.graph = lg.graph - -# g.add_node(n1) - -# n2 = IntegerInvocation(id="2", value=512) -# n3 = IntegerInvocation(id="3", value=256) - -# g.add_node(n2) -# g.add_node(n3) - -# g.add_edge(create_edge("2", "value", "1.width", "value")) -# g.add_edge(create_edge("3", "value", "1.height", "value")) - -# n4 = ShowImageInvocation(id="4") -# g.add_node(n4) -# g.add_edge(create_edge("1.8", "image", "4", "image")) - -# # Validate -# dg = g.nx_graph_flat() -# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} -# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] -# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) -# print(expected_edges) -# print(list(dg.edges)) -# assert set(dg.edges) == set(expected_edges) - - -# def test_graph_fails_to_get_missing_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# with pytest.raises(NodeNotFoundError): -# _ = g.get_node("1.2") - - -# def test_graph_fails_to_enumerate_non_subgraph_node(): -# g = Graph() -# n1 = GraphInvocation(id="1") -# n1.graph = Graph() - -# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") -# n1.graph.add_node(n1_1) - -# g.add_node(n1) - -# n2 = ESRGANInvocation(id="2") -# g.add_node(n2) - -# with pytest.raises(NodeNotFoundError): -# _ = g.get_node("2.1") - - def test_graph_gets_networkx_graph(): g = Graph() n1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") diff --git a/tests/test_session_queue.py b/tests/test_session_queue.py index 48b980539c..bf26b9b002 100644 --- a/tests/test_session_queue.py +++ b/tests/test_session_queue.py @@ -39,30 +39,6 @@ def batch_graph() -> Graph: return g -# def test_populate_graph_with_subgraph(): -# g1 = Graph() -# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) -# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) -# n1 = PromptTestInvocation(id="1", prompt="Banana snake") -# subgraph = Graph() -# subgraph.add_node(n1) -# g1.add_node(GraphInvocation(id="3", graph=subgraph)) - -# nfvs = [ -# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), -# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), -# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), -# ] - -# g2 = populate_graph(g1, nfvs) - -# # do not mutate g1 -# assert g1 is not g2 -# assert g2.get_node("1").prompt == "Strawberry sushi" -# assert g2.get_node("2").prompt == "Strawberry sunday" -# assert g2.get_node("3.1").prompt == "Strawberry snake" - - def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph): b = Batch(graph=batch_graph, data=batch_data_collection, runs=2) t = list(create_session_nfv_tuples(batch=b, maximum=1000))