From 3021c7839028336a6a01db39a145b6582f800388 Mon Sep 17 00:00:00 2001 From: Kyle Schouviller Date: Tue, 14 Mar 2023 23:09:30 -0700 Subject: [PATCH] [nodes] Add Edge data type --- invokeai/app/api/routers/sessions.py | 13 +- invokeai/app/cli_app.py | 31 +++-- invokeai/app/services/graph.py | 179 ++++++++++++++------------- static/dream_web/test.html | 9 +- tests/nodes/test_node_graph.py | 19 +-- tests/nodes/test_nodes.py | 8 +- 6 files changed, 136 insertions(+), 123 deletions(-) diff --git a/invokeai/app/api/routers/sessions.py b/invokeai/app/api/routers/sessions.py index 713b212294..67e3c840c0 100644 --- a/invokeai/app/api/routers/sessions.py +++ b/invokeai/app/api/routers/sessions.py @@ -10,6 +10,7 @@ from pydantic.fields import Field from ...invocations import * from ...invocations.baseinvocation import BaseInvocation from ...services.graph import ( + Edge, EdgeConnection, Graph, GraphExecutionState, @@ -92,7 +93,7 @@ async def get_session( async def add_node( session_id: str = Path(description="The id of the session"), node: Annotated[ - Union[BaseInvocation.get_invocations()], Field(discriminator="type") + Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore ] = Body(description="The node to add"), ) -> str: """Adds a node to the graph""" @@ -125,7 +126,7 @@ async def update_node( session_id: str = Path(description="The id of the session"), node_path: str = Path(description="The path to the node in the graph"), node: Annotated[ - Union[BaseInvocation.get_invocations()], Field(discriminator="type") + Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore ] = Body(description="The new node"), ) -> GraphExecutionState: """Updates a node in the graph and removes all linked edges""" @@ -186,7 +187,7 @@ async def delete_node( ) async def add_edge( session_id: str = Path(description="The id of the session"), - edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"), + edge: Edge = Body(description="The edge to add"), ) -> GraphExecutionState: """Adds an edge to the graph""" session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id) @@ -228,9 +229,9 @@ async def delete_edge( return Response(status_code=404) try: - edge = ( - EdgeConnection(node_id=from_node_id, field=from_field), - EdgeConnection(node_id=to_node_id, field=to_field), + edge = Edge( + source=EdgeConnection(node_id=from_node_id, field=from_field), + destination=EdgeConnection(node_id=to_node_id, field=to_field) ) session.delete_edge(edge) ApiDependencies.invoker.services.graph_execution_manager.set( diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py index 732a233cb4..6390253250 100644 --- a/invokeai/app/cli_app.py +++ b/invokeai/app/cli_app.py @@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation from .services.events import EventServiceBase from .services.model_manager_initializer import get_model_manager from .services.restoration_services import RestorationServices -from .services.graph import EdgeConnection, GraphExecutionState +from .services.graph import Edge, EdgeConnection, GraphExecutionState from .services.image_storage import DiskImageStorage from .services.invocation_queue import MemoryInvocationQueue from .services.invocation_services import InvocationServices @@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser: def generate_matching_edges( a: BaseInvocation, b: BaseInvocation -) -> list[tuple[EdgeConnection, EdgeConnection]]: +) -> list[Edge]: """Generates all possible edges between two invocations""" atype = type(a) btype = type(b) @@ -94,9 +94,9 @@ def generate_matching_edges( matching_fields = matching_fields.difference(invalid_fields) edges = [ - ( - EdgeConnection(node_id=a.id, field=field), - EdgeConnection(node_id=b.id, field=field), + Edge( + source=EdgeConnection(node_id=a.id, field=field), + destination=EdgeConnection(node_id=b.id, field=field) ) for field in matching_fields ] @@ -111,16 +111,15 @@ class SessionError(Exception): def invoke_all(context: CliContext): """Runs all invocations in the specified session""" context.invoker.invoke(context.session, invoke_all=True) - while not context.session.is_complete(): + while not context.get_session().is_complete(): # Wait some time - session = context.get_session() time.sleep(0.1) # Print any errors if context.session.has_error(): for n in context.session.errors: print( - f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}" + f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}" ) raise SessionError() @@ -203,7 +202,7 @@ def invoke_cli(): continue # Pipe previous command output (if there was a previous command) - edges = [] + edges: list[Edge] = list() if len(history) > 0 or current_id != start_id: from_id = ( history[0] if current_id == start_id else str(current_id - 1) @@ -225,19 +224,19 @@ def invoke_cli(): matching_edges = generate_matching_edges( link_node, command.command ) - matching_destinations = [e[1] for e in matching_edges] - edges = [e for e in edges if e[1] not in matching_destinations] + matching_destinations = [e.destination for e in matching_edges] + edges = [e for e in edges if e.destination not in matching_destinations] edges.extend(matching_edges) if "link" in args and args["link"]: for link in args["link"]: - edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]] + edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]] edges.append( - ( - EdgeConnection(node_id=link[1], field=link[0]), - EdgeConnection( + Edge( + source=EdgeConnection(node_id=link[1], field=link[0]), + destination=EdgeConnection( node_id=command.command.id, field=link[2] - ), + ) ) ) diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index aa462ab170..8134b47167 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -44,6 +44,11 @@ class EdgeConnection(BaseModel): return hash(f"{self.node_id}.{self.field}") +class Edge(BaseModel): + source: EdgeConnection = Field(description="The connection for the edge's from node and field") + destination: EdgeConnection = Field(description="The connection for the edge's to node and field") + + def get_output_field(node: BaseInvocation, field: str) -> Any: node_type = type(node) node_outputs = get_type_hints(node_type.get_output_type()) @@ -194,7 +199,7 @@ class Graph(BaseModel): nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field( description="The nodes in this graph", default_factory=dict ) - edges: list[tuple[EdgeConnection, EdgeConnection]] = Field( + edges: list[Edge] = Field( description="The connections between nodes and their fields in this graph", default_factory=list, ) @@ -251,7 +256,7 @@ class Graph(BaseModel): except NodeNotFoundError: pass # Ignore, not doesn't exist (should this throw?) - def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + def add_edge(self, edge: Edge) -> None: """Adds an edge to a graph :raises InvalidEdgeError: the provided edge is invalid. @@ -262,7 +267,7 @@ class Graph(BaseModel): else: raise InvalidEdgeError() - def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: + def delete_edge(self, edge: Edge) -> None: """Deletes an edge from a graph""" try: @@ -280,7 +285,7 @@ class Graph(BaseModel): # Validate all edges reference nodes in the graph node_ids = set( - [e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges] + [e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges] ) if not all((self.has_node(node_id) for node_id in node_ids)): return False @@ -294,10 +299,10 @@ class Graph(BaseModel): if not all( ( are_connections_compatible( - self.get_node(e[0].node_id), - e[0].field, - self.get_node(e[1].node_id), - e[1].field, + self.get_node(e.source.node_id), + e.source.field, + self.get_node(e.destination.node_id), + e.destination.field, ) for e in self.edges ) @@ -328,58 +333,58 @@ class Graph(BaseModel): return True - def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: + def _is_edge_valid(self, edge: Edge) -> bool: """Validates that a new edge doesn't create a cycle in the graph""" # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly) try: - from_node = self.get_node(edge[0].node_id) - to_node = self.get_node(edge[1].node_id) + from_node = self.get_node(edge.source.node_id) + to_node = self.get_node(edge.destination.node_id) except NodeNotFoundError: return False # Validate that an edge to this node+field doesn't already exist - input_edges = self._get_input_edges(edge[1].node_id, edge[1].field) + input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field) if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation): return False # Validate that no cycles would be created g = self.nx_graph_flat() - g.add_edge(edge[0].node_id, edge[1].node_id) + g.add_edge(edge.source.node_id, edge.destination.node_id) if not nx.is_directed_acyclic_graph(g): return False # Validate that the field types are compatible if not are_connections_compatible( - from_node, edge[0].field, to_node, edge[1].field + from_node, edge.source.field, to_node, edge.destination.field ): return False # Validate if iterator output type matches iterator input type (if this edge results in both being set) - if isinstance(to_node, IterateInvocation) and edge[1].field == "collection": + if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection": if not self._is_iterator_connection_valid( - edge[1].node_id, new_input=edge[0] + edge.destination.node_id, new_input=edge.source ): return False # Validate if iterator input type matches output type (if this edge results in both being set) - if isinstance(from_node, IterateInvocation) and edge[0].field == "item": + if isinstance(from_node, IterateInvocation) and edge.source.field == "item": if not self._is_iterator_connection_valid( - edge[0].node_id, new_output=edge[1] + edge.source.node_id, new_output=edge.destination ): return False # Validate if collector input type matches output type (if this edge results in both being set) - if isinstance(to_node, CollectInvocation) and edge[1].field == "item": + if isinstance(to_node, CollectInvocation) and edge.destination.field == "item": if not self._is_collector_connection_valid( - edge[1].node_id, new_input=edge[0] + edge.destination.node_id, new_input=edge.source ): return False # Validate if collector output type matches input type (if this edge results in both being set) - if isinstance(from_node, CollectInvocation) and edge[0].field == "collection": + if isinstance(from_node, CollectInvocation) and edge.source.field == "collection": if not self._is_collector_connection_valid( - edge[0].node_id, new_output=edge[1] + edge.source.node_id, new_output=edge.destination ): return False @@ -438,15 +443,15 @@ class Graph(BaseModel): # Remove the graph prefix from the node path new_graph_node_path = ( new_node.id - if "." not in edge[1].node_id - else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}' + if "." not in edge.destination.node_id + else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}' ) graph.add_edge( - ( - edge[0], - EdgeConnection( - node_id=new_graph_node_path, field=edge[1].field - ), + Edge( + source=edge.source, + destination=EdgeConnection( + node_id=new_graph_node_path, field=edge.destination.field + ) ) ) @@ -454,51 +459,51 @@ class Graph(BaseModel): # Remove the graph prefix from the node path new_graph_node_path = ( new_node.id - if "." not in edge[0].node_id - else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}' + if "." not in edge.source.node_id + else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}' ) graph.add_edge( - ( - EdgeConnection( - node_id=new_graph_node_path, field=edge[0].field + Edge( + source=EdgeConnection( + node_id=new_graph_node_path, field=edge.source.field ), - edge[1], + destination=edge.destination ) ) def _get_input_edges( self, node_path: str, field: Optional[str] = None - ) -> list[tuple[EdgeConnection, EdgeConnection]]: + ) -> list[Edge]: """Gets all input edges for a node""" edges = self._get_input_edges_and_graphs(node_path) # Filter to edges that match the field - filtered_edges = (e for e in edges if field is None or e[2][1].field == field) + filtered_edges = (e for e in edges if field is None or e[2].destination.field == field) # Create full node paths for each edge return [ - ( - EdgeConnection( - node_id=self._get_node_path(e[0].node_id, prefix=prefix), - field=e[0].field, - ), - EdgeConnection( - node_id=self._get_node_path(e[1].node_id, prefix=prefix), - field=e[1].field, + Edge( + source=EdgeConnection( + node_id=self._get_node_path(e.source.node_id, prefix=prefix), + field=e.source.field, ), + destination=EdgeConnection( + node_id=self._get_node_path(e.destination.node_id, prefix=prefix), + field=e.destination.field, + ) ) for _, prefix, e in filtered_edges ] def _get_input_edges_and_graphs( self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]: + ) -> list[tuple["Graph", str, Edge]]: """Gets all input edges for a node along with the graph they are in and the graph's path""" edges = list() # Return any input edges that appear in this graph edges.extend( - [(self, prefix, e) for e in self.edges if e[1].node_id == node_path] + [(self, prefix, e) for e in self.edges if e.destination.node_id == node_path] ) node_id = ( @@ -522,37 +527,37 @@ class Graph(BaseModel): def _get_output_edges( self, node_path: str, field: str - ) -> list[tuple[EdgeConnection, EdgeConnection]]: + ) -> list[Edge]: """Gets all output edges for a node""" edges = self._get_output_edges_and_graphs(node_path) # Filter to edges that match the field - filtered_edges = (e for e in edges if e[2][0].field == field) + filtered_edges = (e for e in edges if e[2].source.field == field) # Create full node paths for each edge return [ - ( - EdgeConnection( - node_id=self._get_node_path(e[0].node_id, prefix=prefix), - field=e[0].field, - ), - EdgeConnection( - node_id=self._get_node_path(e[1].node_id, prefix=prefix), - field=e[1].field, + Edge( + source=EdgeConnection( + node_id=self._get_node_path(e.source.node_id, prefix=prefix), + field=e.source.field, ), + destination=EdgeConnection( + node_id=self._get_node_path(e.destination.node_id, prefix=prefix), + field=e.destination.field, + ) ) for _, prefix, e in filtered_edges ] def _get_output_edges_and_graphs( self, node_path: str, prefix: Optional[str] = None - ) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]: + ) -> list[tuple["Graph", str, Edge]]: """Gets all output edges for a node along with the graph they are in and the graph's path""" edges = list() # Return any input edges that appear in this graph edges.extend( - [(self, prefix, e) for e in self.edges if e[0].node_id == node_path] + [(self, prefix, e) for e in self.edges if e.source.node_id == node_path] ) node_id = ( @@ -580,8 +585,8 @@ class Graph(BaseModel): new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")]) - outputs = list([e[1] for e in self._get_output_edges(node_path, "item")]) + inputs = list([e.source for e in self._get_input_edges(node_path, "collection")]) + outputs = list([e.destination for e in self._get_output_edges(node_path, "item")]) if new_input is not None: inputs.append(new_input) @@ -622,8 +627,8 @@ class Graph(BaseModel): new_input: Optional[EdgeConnection] = None, new_output: Optional[EdgeConnection] = None, ) -> bool: - inputs = list([e[0] for e in self._get_input_edges(node_path, "item")]) - outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")]) + inputs = list([e.source for e in self._get_input_edges(node_path, "item")]) + outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")]) if new_input is not None: inputs.append(new_input) @@ -684,7 +689,7 @@ class Graph(BaseModel): # TODO: Cache this? g = nx.DiGraph() g.add_nodes_from([n for n in self.nodes.keys()]) - g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges])) + g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges])) return g def nx_graph_flat( @@ -711,7 +716,7 @@ class Graph(BaseModel): # TODO: figure out if iteration nodes need to be expanded - unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges]) + unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges]) g.add_edges_from( [ (self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) @@ -841,13 +846,13 @@ class GraphExecutionState(BaseModel): input_collection_prepared_node_id = next( n[1] for n in iteration_node_map - if n[0] == input_collection_edge[0].node_id + if n[0] == input_collection_edge.source.node_id ) input_collection_prepared_node_output = self.results[ input_collection_prepared_node_id ] input_collection = getattr( - input_collection_prepared_node_output, input_collection_edge[0].field + input_collection_prepared_node_output, input_collection_edge.source.field ) self_iteration_count = len(input_collection) @@ -864,11 +869,11 @@ class GraphExecutionState(BaseModel): new_edges = list() for edge in input_edges: for input_node_id in ( - n[1] for n in iteration_node_map if n[0] == edge[0].node_id + n[1] for n in iteration_node_map if n[0] == edge.source.node_id ): - new_edge = ( - EdgeConnection(node_id=input_node_id, field=edge[0].field), - EdgeConnection(node_id="", field=edge[1].field), + new_edge = Edge( + source=EdgeConnection(node_id=input_node_id, field=edge.source.field), + destination=EdgeConnection(node_id="", field=edge.destination.field), ) new_edges.append(new_edge) @@ -893,9 +898,9 @@ class GraphExecutionState(BaseModel): # Add new edges to execution graph for edge in new_edges: - new_edge = ( - edge[0], - EdgeConnection(node_id=new_node.id, field=edge[1].field), + new_edge = Edge( + source=edge.source, + destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field), ) self.execution_graph.add_edge(new_edge) @@ -1043,26 +1048,26 @@ class GraphExecutionState(BaseModel): return self.execution_graph.nodes[next_node] def _prepare_inputs(self, node: BaseInvocation): - input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id] + input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id] if isinstance(node, CollectInvocation): output_collection = [ - getattr(self.results[edge[0].node_id], edge[0].field) + getattr(self.results[edge.source.node_id], edge.source.field) for edge in input_edges - if edge[1].field == "item" + if edge.destination.field == "item" ] setattr(node, "collection", output_collection) else: for edge in input_edges: - output_value = getattr(self.results[edge[0].node_id], edge[0].field) - setattr(node, edge[1].field, output_value) + output_value = getattr(self.results[edge.source.node_id], edge.source.field) + setattr(node, edge.destination.field, output_value) # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state - def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool: + def _is_edge_valid(self, edge: Edge) -> bool: if not self._is_edge_valid(edge): return False # Invalid if destination has already been prepared or executed - if edge[1].node_id in self.source_prepared_mapping: + if edge.destination.node_id in self.source_prepared_mapping: return False # Otherwise, the edge is valid @@ -1089,17 +1094,17 @@ class GraphExecutionState(BaseModel): ) self.graph.delete_node(node_path) - def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: - if not self._is_node_updatable(edge[1].node_id): + def add_edge(self, edge: Edge) -> None: + if not self._is_node_updatable(edge.destination.node_id): raise NodeAlreadyExecutedError( - f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to" + f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to" ) self.graph.add_edge(edge) - def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None: - if not self._is_node_updatable(edge[1].node_id): + def delete_edge(self, edge: Edge) -> None: + if not self._is_node_updatable(edge.destination.node_id): raise NodeAlreadyExecutedError( - f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted" + f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted" ) self.graph.delete_edge(edge) diff --git a/static/dream_web/test.html b/static/dream_web/test.html index e99abb3703..5fd6918d66 100644 --- a/static/dream_web/test.html +++ b/static/dream_web/test.html @@ -105,17 +105,20 @@ // Start building nodes var id = 1; - var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed}; + var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed}; + id++; + var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)}; id++; var upscaleNode = {"id": id.toString(), "type": "show_image" }; id++ nodes = {}; nodes[initialNode.id] = initialNode; + nodes[i2iNode.id] = i2iNode; nodes[upscaleNode.id] = upscaleNode; links = [ - [{ "node_id": initialNode.id, field: "image" }, - { "node_id": upscaleNode.id, field: "image" }] + { "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }}, + { "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }} ]; // expandSize = 128; // for (var i = 0; i < 6; ++i) { diff --git a/tests/nodes/test_node_graph.py b/tests/nodes/test_node_graph.py index d432234aec..b864e1e47a 100644 --- a/tests/nodes/test_node_graph.py +++ b/tests/nodes/test_node_graph.py @@ -1,15 +1,18 @@ from invokeai.app.invocations.image import * from .test_nodes import ListPassThroughInvocation, PromptTestInvocation -from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation +from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation from invokeai.app.invocations.upscale import UpscaleInvocation import pytest # Helpers -def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: - return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge: + return Edge( + source=EdgeConnection(node_id = from_id, field = from_field), + destination=EdgeConnection(node_id = to_id, field = to_field) + ) # Tests def test_connections_are_compatible(): @@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change(): assert g.get_node("3").prompt == "Banana sushi" assert len(g.edges) == 1 - assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges + assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges def test_graph_fails_to_update_node_id_if_conflict(): g = Graph() @@ -490,10 +493,10 @@ def test_graph_can_deserialize(): assert g2.nodes['1'] is not None assert g2.nodes['2'] is not None assert len(g2.edges) == 1 - assert g2.edges[0][0].node_id == '1' - assert g2.edges[0][0].field == 'image' - assert g2.edges[0][1].node_id == '2' - assert g2.edges[0][1].field == 'image' + assert g2.edges[0].source.node_id == '1' + assert g2.edges[0].source.field == 'image' + assert g2.edges[0].destination.node_id == '2' + assert g2.edges[0].destination.field == 'image' def test_graph_can_generate_schema(): # Not throwing on this line is sufficient diff --git a/tests/nodes/test_nodes.py b/tests/nodes/test_nodes.py index c3427ac03b..d16d67d815 100644 --- a/tests/nodes/test_nodes.py +++ b/tests/nodes/test_nodes.py @@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation): from invokeai.app.services.events import EventServiceBase -from invokeai.app.services.graph import EdgeConnection +from invokeai.app.services.graph import Edge, EdgeConnection -def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]: - return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field)) +def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge: + return Edge( + source=EdgeConnection(node_id = from_id, field = from_field), + destination=EdgeConnection(node_id = to_id, field = to_field)) class TestEvent: