diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 1b53f64222..4df9f0c4b0 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -184,10 +184,6 @@ class NodeIdMismatchError(ValueError): pass -class InvalidSubGraphError(ValueError): - pass - - class CyclicalGraphError(ValueError): pass @@ -196,25 +192,6 @@ class UnknownGraphValidationError(ValueError): pass -# TODO: Create and use an Empty output? -@invocation_output("graph_output") -class GraphInvocationOutput(BaseInvocationOutput): - pass - - -# TODO: Fill this out and move to invocations -@invocation("graph", version="1.0.0") -class GraphInvocation(BaseInvocation): - """Execute a graph""" - - # TODO: figure out how to create a default here - graph: "Graph" = InputField(description="The graph to run", default=None) - - def invoke(self, context: InvocationContext) -> GraphInvocationOutput: - """Invoke with provided services and return outputs.""" - return GraphInvocationOutput() - - @invocation_output("iterate_output") class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" @@ -346,41 +323,21 @@ class Graph(BaseModel): self.nodes[node.id] = node - def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]: - """Returns the graph and node id for a node path.""" - # Materialized graphs may have nodes at the top level - if node_path in self.nodes: - return (self, node_path) - - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - if node_id not in self.nodes: - raise NodeNotFoundError(f"Node {node_path} not found in graph") - - node = self.nodes[node_id] - - if not isinstance(node, GraphInvocation): - # There's more node path left but this isn't a graph - failure - raise NodeNotFoundError("Node path terminated early at a non-graph node") - - return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :]) - - def delete_node(self, node_path: str) -> None: + def delete_node(self, node_id: str) -> None: """Deletes a node from a graph""" try: - graph, node_id = self._get_graph_and_node(node_path) - # Delete edges for this node - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) - for edge_graph, _, edge in input_edges: - edge_graph.delete_edge(edge) + for edge in input_edges: + self.delete_edge(edge) - for edge_graph, _, edge in output_edges: - edge_graph.delete_edge(edge) + for edge in output_edges: + self.delete_edge(edge) - del graph.nodes[node_id] + del self.nodes[node_id] except NodeNotFoundError: pass # Ignore, not doesn't exist (should this throw?) @@ -430,13 +387,6 @@ class Graph(BaseModel): if k != v.id: raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") - # Validate all subgraphs - for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)): - try: - gn.graph.validate_self() - except Exception as e: - raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e - # Validate that all edges match nodes and fields in the graph for edge in self.edges: source_node = self.nodes.get(edge.source.node_id, None) @@ -498,7 +448,6 @@ class Graph(BaseModel): except ( DuplicateNodeIdError, NodeIdMismatchError, - InvalidSubGraphError, NodeNotFoundError, NodeFieldNotFoundError, CyclicalGraphError, @@ -519,7 +468,7 @@ class Graph(BaseModel): def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" - # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) + # Validate that the nodes exist try: from_node = self.get_node(edge.source.node_id) to_node = self.get_node(edge.destination.node_id) @@ -586,171 +535,90 @@ class Graph(BaseModel): f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}" ) - def has_node(self, node_path: str) -> bool: + def has_node(self, node_id: str) -> bool: """Determines whether or not a node exists in the graph.""" try: - n = self.get_node(node_path) - if n is not None: - return True - else: - return False + _ = self.get_node(node_id) + return True except NodeNotFoundError: return False - def get_node(self, node_path: str) -> BaseInvocation: - """Gets a node from the graph using a node path.""" - # Materialized graphs may have nodes at the top level - graph, node_id = self._get_graph_and_node(node_path) - return graph.nodes[node_id] + def get_node(self, node_id: str) -> BaseInvocation: + """Gets a node from the graph.""" + try: + return self.nodes[node_id] + except KeyError as e: + raise NodeNotFoundError(f"Node {node_id} not found in graph") from e - def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str: - return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}" - - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: """Updates a node in the graph.""" - graph, node_id = self._get_graph_and_node(node_path) - node = graph.nodes[node_id] + node = self.nodes[node_id] # Ensure the node type matches the new node if type(node) is not type(new_node): - raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}") + raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}") # Ensure the new id is either the same or is not in the graph - prefix = None if "." not in node_path else node_path[: node_path.rindex(".")] - new_path = self._get_node_path(new_node.id, prefix=prefix) - if new_node.id != node.id and self.has_node(new_path): - raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph") + if new_node.id != node.id and self.has_node(new_node.id): + raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph") # Set the new node in the graph - graph.nodes[new_node.id] = new_node + self.nodes[new_node.id] = new_node if new_node.id != node.id: - input_edges = self._get_input_edges_and_graphs(node_path) - output_edges = self._get_output_edges_and_graphs(node_path) + input_edges = self._get_input_edges(node_id) + output_edges = self._get_output_edges(node_id) # Delete node and all edges - graph.delete_node(node_path) + self.delete_node(node_id) # Create new edges for each input and output - for graph, _, edge in input_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.destination.node_id - else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in input_edges: + self.add_edge( Edge( source=edge.source, - destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field), + destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field), ) ) - for graph, _, edge in output_edges: - # Remove the graph prefix from the node path - new_graph_node_path = ( - new_node.id - if "." not in edge.source.node_id - else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}' - ) - graph.add_edge( + for edge in output_edges: + self.add_edge( Edge( - source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field), + source=EdgeConnection(node_id=new_node.id, field=edge.source.field), destination=edge.destination, ) ) - def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]: - """Gets all input edges for a node""" - edges = self._get_input_edges_and_graphs(node_path) + def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all input edges for a node. If field is provided, only edges to that field are returned.""" - # Filter to edges that match the field - filtered_edges = (e for e in edges if field is None or e[2].destination.field == field) + edges = [e for e in self.edges if e.destination.node_id == node_id] - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] + if field is None: + return edges - def _get_input_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all input edges for a node along with the graph they are in and the graph's path""" - edges = [] + filtered_edges = [e for e in edges if e.destination.field == field] - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]) + return filtered_edges - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] + def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]: + """Gets all output edges for a node. If field is provided, only edges from that field are returned.""" + edges = [e for e in self.edges if e.source.node_id == node_id] - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) + if field is None: + return edges - return edges + filtered_edges = [e for e in edges if e.source.field == field] - def _get_output_edges(self, node_path: str, field: str) -> list[Edge]: - """Gets all output edges for a node""" - edges = self._get_output_edges_and_graphs(node_path) - - # Filter to edges that match the field - filtered_edges = (e for e in edges if e[2].source.field == field) - - # Create full node paths for each edge - return [ - Edge( - source=EdgeConnection( - node_id=self._get_node_path(e.source.node_id, prefix=prefix), - field=e.source.field, - ), - destination=EdgeConnection( - node_id=self._get_node_path(e.destination.node_id, prefix=prefix), - field=e.destination.field, - ), - ) - for _, prefix, e in filtered_edges - ] - - def _get_output_edges_and_graphs( - self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", Union[str, None], Edge]]: - """Gets all output edges for a node along with the graph they are in and the graph's path""" - edges = [] - - # Return any input edges that appear in this graph - edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path]) - - node_id = node_path if "." not in node_path else node_path[: node_path.index(".")] - node = self.nodes[node_id] - - if isinstance(node, GraphInvocation): - graph = node.graph - graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix) - graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path) - edges.extend(graph_edges) - - return edges + return filtered_edges def _is_iterator_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "collection")] - outputs = [e.destination for e in self._get_output_edges(node_path, "item")] + inputs = [e.source for e in self._get_input_edges(node_id, "collection")] + outputs = [e.destination for e in self._get_output_edges(node_id, "item")] if new_input is not None: inputs.append(new_input) @@ -778,12 +646,12 @@ class Graph(BaseModel): def _is_collector_connection_valid( self, - node_path: str, + node_id: str, new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = [e.source for e in self._get_input_edges(node_path, "item")] - outputs = [e.destination for e in self._get_output_edges(node_path, "collection")] + inputs = [e.source for e in self._get_input_edges(node_id, "item")] + outputs = [e.destination for e in self._get_output_edges(node_id, "collection")] if new_input is not None: inputs.append(new_input) @@ -839,27 +707,17 @@ class Graph(BaseModel): g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges}) return g - def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph: + def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph: """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)""" g = nx_graph or nx.DiGraph() # Add all nodes from this graph except graph/iteration nodes - g.add_nodes_from( - [ - self._get_node_path(n.id, prefix) - for n in self.nodes.values() - if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation) - ] - ) - - # Expand graph nodes - for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)): - g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix)) + g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)]) # TODO: figure out if iteration nodes need to be expanded unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges} - g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]) + g.add_edges_from([(e[0], e[1]) for e in unique_edges]) return g @@ -1017,17 +875,17 @@ class GraphExecutionState(BaseModel): """Returns true if the graph has any errors""" return len(self.errors) > 0 - def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: + def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]: """Prepares an iteration node and connects all edges, returning the new node id""" - node = self.graph.get_node(node_path) + node = self.graph.get_node(node_id) self_iteration_count = -1 # If this is an iterator node, we must create a copy for each iteration if isinstance(node, IterateInvocation): # Get input collection edge (should error if there are no inputs) - input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection"))) + input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection"))) input_collection_prepared_node_id = next( n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id ) @@ -1041,7 +899,7 @@ class GraphExecutionState(BaseModel): return new_nodes # Get all input edges - input_edges = self.graph._get_input_edges(node_path) + input_edges = self.graph._get_input_edges(node_id) # Create new edges for this iteration # For collect nodes, this may contain multiple inputs to the same field @@ -1068,10 +926,10 @@ class GraphExecutionState(BaseModel): # Add to execution graph self.execution_graph.add_node(new_node) - self.prepared_source_mapping[new_node.id] = node_path - if node_path not in self.source_prepared_mapping: - self.source_prepared_mapping[node_path] = set() - self.source_prepared_mapping[node_path].add(new_node.id) + self.prepared_source_mapping[new_node.id] = node_id + if node_id not in self.source_prepared_mapping: + self.source_prepared_mapping[node_id] = set() + self.source_prepared_mapping[node_id].add(new_node.id) # Add new edges to execution graph for edge in new_edges: @@ -1175,13 +1033,13 @@ class GraphExecutionState(BaseModel): def _get_iteration_node( self, - source_node_path: str, + source_node_id: str, graph: nx.DiGraph, execution_graph: nx.DiGraph, prepared_iterator_nodes: list[str], ) -> Optional[str]: """Gets the prepared version of the specified source node that matches every iteration specified""" - prepared_nodes = self.source_prepared_mapping[source_node_path] + prepared_nodes = self.source_prepared_mapping[source_node_id] if len(prepared_nodes) == 1: return next(iter(prepared_nodes)) @@ -1192,7 +1050,7 @@ class GraphExecutionState(BaseModel): # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source) iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes] - parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)] + parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)] return next( (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)), @@ -1261,19 +1119,19 @@ class GraphExecutionState(BaseModel): def add_node(self, node: BaseInvocation) -> None: self.graph.add_node(node) - def update_node(self, node_path: str, new_node: BaseInvocation) -> None: - if not self._is_node_updatable(node_path): + def update_node(self, node_id: str, new_node: BaseInvocation) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be updated" + f"Node {node_id} has already been prepared or executed and cannot be updated" ) - self.graph.update_node(node_path, new_node) + self.graph.update_node(node_id, new_node) - def delete_node(self, node_path: str) -> None: - if not self._is_node_updatable(node_path): + def delete_node(self, node_id: str) -> None: + if not self._is_node_updatable(node_id): raise NodeAlreadyExecutedError( - f"Node {node_path} has already been prepared or executed and cannot be deleted" + f"Node {node_id} has already been prepared or executed and cannot be deleted" ) - self.graph.delete_node(node_path) + self.graph.delete_node(node_id) def add_edge(self, edge: Edge) -> None: if not self._is_node_updatable(edge.destination.node_id): diff --git a/tests/aa_nodes/test_invoker.py b/tests/aa_nodes/test_invoker.py index f67b5a2ac5..38fcf859a5 100644 --- a/tests/aa_nodes/test_invoker.py +++ b/tests/aa_nodes/test_invoker.py @@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState @pytest.fixture @@ -35,17 +35,6 @@ def simple_graph(): return g -@pytest.fixture -def graph_with_subgraph(): - sub_g = Graph() - sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - sub_g.add_node(TextToImageTestInvocation(id="2")) - sub_g.add_edge(create_edge("1", "prompt", "2", "prompt")) - g = Graph() - g.add_node(GraphInvocation(id="1", graph=sub_g)) - return g - - # This must be defined here to avoid issues with the dynamic creation of the union of all invocation types # Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate # the test invocations. diff --git a/tests/aa_nodes/test_node_graph.py b/tests/aa_nodes/test_node_graph.py index 12a181f392..94682962ad 100644 --- a/tests/aa_nodes/test_node_graph.py +++ b/tests/aa_nodes/test_node_graph.py @@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.image import ShowImageInvocation -from invokeai.app.invocations.math import AddInvocation, SubtractInvocation from invokeai.app.invocations.primitives import ( FloatCollectionInvocation, FloatInvocation, @@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import ( StringInvocation, ) from invokeai.app.invocations.upscale import ESRGANInvocation -from invokeai.app.services.shared.default_graphs import create_text_to_image from invokeai.app.services.shared.graph import ( CollectInvocation, Edge, EdgeConnection, Graph, - GraphInvocation, InvalidEdgeError, IterateInvocation, NodeAlreadyInGraphError, @@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes(): assert g.is_valid() is False -def test_graph_invalid_if_subgraph_invalid(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_invalid_if_subgraph_invalid(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") - n1.graph.nodes[n1_1.id] = n1_1 - e1 = create_edge("1", "image", "2", "image") - n1.graph.edges.append(e1) +# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi") +# n1.graph.nodes[n1_1.id] = n1_1 +# e1 = create_edge("1", "image", "2", "image") +# n1.graph.edges.append(e1) - g.nodes[n1.id] = n1 +# g.nodes[n1.id] = n1 - assert g.is_valid() is False +# assert g.is_valid() is False def test_graph_invalid_if_has_cycle(): @@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection(): assert g.is_valid() is False -# TODO: Subgraph operations -def test_graph_gets_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# # TODO: Subgraph operations +# def test_graph_gets_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - result = g.get_node("1.1") +# result = g.get_node("1.1") - assert result is not None - assert result.id == "1" - assert result == n1_1 +# assert result is not None +# assert result.id == "1" +# assert result == n1_1 -def test_graph_expands_subgraph(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_expands_subgraph(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = AddInvocation(id="1", a=1, b=2) - n1_2 = SubtractInvocation(id="2", b=3) - n1.graph.add_node(n1_1) - n1.graph.add_node(n1_2) - n1.graph.add_edge(create_edge("1", "value", "2", "a")) +# n1_1 = AddInvocation(id="1", a=1, b=2) +# n1_2 = SubtractInvocation(id="2", b=3) +# n1.graph.add_node(n1_1) +# n1.graph.add_node(n1_2) +# n1.graph.add_edge(create_edge("1", "value", "2", "a")) - g.add_node(n1) +# g.add_node(n1) - n2 = AddInvocation(id="2", b=5) - g.add_node(n2) - g.add_edge(create_edge("1.2", "value", "2", "a")) +# n2 = AddInvocation(id="2", b=5) +# g.add_node(n2) +# g.add_edge(create_edge("1.2", "value", "2", "a")) - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.1", "1.2", "2"} - assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} +# dg = g.nx_graph_flat() +# assert set(dg.nodes) == {"1.1", "1.2", "2"} +# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")} -def test_graph_subgraph_t2i(): - g = Graph() - n1 = GraphInvocation(id="1") +# def test_graph_subgraph_t2i(): +# g = Graph() +# n1 = GraphInvocation(id="1") - # Get text to image default graph - lg = create_text_to_image() - n1.graph = lg.graph +# # Get text to image default graph +# lg = create_text_to_image() +# n1.graph = lg.graph - g.add_node(n1) +# g.add_node(n1) - n2 = IntegerInvocation(id="2", value=512) - n3 = IntegerInvocation(id="3", value=256) +# n2 = IntegerInvocation(id="2", value=512) +# n3 = IntegerInvocation(id="3", value=256) - g.add_node(n2) - g.add_node(n3) +# g.add_node(n2) +# g.add_node(n3) - g.add_edge(create_edge("2", "value", "1.width", "value")) - g.add_edge(create_edge("3", "value", "1.height", "value")) +# g.add_edge(create_edge("2", "value", "1.width", "value")) +# g.add_edge(create_edge("3", "value", "1.height", "value")) - n4 = ShowImageInvocation(id="4") - g.add_node(n4) - g.add_edge(create_edge("1.8", "image", "4", "image")) +# n4 = ShowImageInvocation(id="4") +# g.add_node(n4) +# g.add_edge(create_edge("1.8", "image", "4", "image")) - # Validate - dg = g.nx_graph_flat() - assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} - expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] - expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) - print(expected_edges) - print(list(dg.edges)) - assert set(dg.edges) == set(expected_edges) +# # Validate +# dg = g.nx_graph_flat() +# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"} +# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges] +# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")]) +# print(expected_edges) +# print(list(dg.edges)) +# assert set(dg.edges) == set(expected_edges) -def test_graph_fails_to_get_missing_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_fails_to_get_missing_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - with pytest.raises(NodeNotFoundError): - _ = g.get_node("1.2") +# with pytest.raises(NodeNotFoundError): +# _ = g.get_node("1.2") -def test_graph_fails_to_enumerate_non_subgraph_node(): - g = Graph() - n1 = GraphInvocation(id="1") - n1.graph = Graph() +# def test_graph_fails_to_enumerate_non_subgraph_node(): +# g = Graph() +# n1 = GraphInvocation(id="1") +# n1.graph = Graph() - n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") - n1.graph.add_node(n1_1) +# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi") +# n1.graph.add_node(n1_1) - g.add_node(n1) +# g.add_node(n1) - n2 = ESRGANInvocation(id="2") - g.add_node(n2) +# n2 = ESRGANInvocation(id="2") +# g.add_node(n2) - with pytest.raises(NodeNotFoundError): - _ = g.get_node("2.1") +# with pytest.raises(NodeNotFoundError): +# _ = g.get_node("2.1") def test_graph_gets_networkx_graph(): diff --git a/tests/aa_nodes/test_session_queue.py b/tests/aa_nodes/test_session_queue.py index b15bb9df36..bfe6444de8 100644 --- a/tests/aa_nodes/test_session_queue.py +++ b/tests/aa_nodes/test_session_queue.py @@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import ( NodeFieldValue, calc_session_count, create_session_nfv_tuples, - populate_graph, prepare_values_to_insert, ) -from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation +from invokeai.app.services.shared.graph import Graph, GraphExecutionState from tests.aa_nodes.test_nodes import PromptTestInvocation @@ -39,28 +38,28 @@ def batch_graph() -> Graph: return g -def test_populate_graph_with_subgraph(): - g1 = Graph() - g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) - g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) - n1 = PromptTestInvocation(id="1", prompt="Banana snake") - subgraph = Graph() - subgraph.add_node(n1) - g1.add_node(GraphInvocation(id="3", graph=subgraph)) +# def test_populate_graph_with_subgraph(): +# g1 = Graph() +# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi")) +# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi")) +# n1 = PromptTestInvocation(id="1", prompt="Banana snake") +# subgraph = Graph() +# subgraph.add_node(n1) +# g1.add_node(GraphInvocation(id="3", graph=subgraph)) - nfvs = [ - NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), - NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), - NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), - ] +# nfvs = [ +# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"), +# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"), +# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"), +# ] - g2 = populate_graph(g1, nfvs) +# g2 = populate_graph(g1, nfvs) - # do not mutate g1 - assert g1 is not g2 - assert g2.get_node("1").prompt == "Strawberry sushi" - assert g2.get_node("2").prompt == "Strawberry sunday" - assert g2.get_node("3.1").prompt == "Strawberry snake" +# # do not mutate g1 +# assert g1 is not g2 +# assert g2.get_node("1").prompt == "Strawberry sushi" +# assert g2.get_node("2").prompt == "Strawberry sunday" +# assert g2.get_node("3.1").prompt == "Strawberry snake" def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):