diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index 9dccd14026..ffe737024d 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -170,6 +170,18 @@ class NodeIdMismatchError(ValueError): pass +class InvalidSubGraphError(ValueError): + pass + + +class CyclicalGraphError(ValueError): + pass + + +class UnknownGraphValidationError(ValueError): + pass + + # TODO: Create and use an Empty output? @invocation_output("graph_output") class GraphInvocationOutput(BaseInvocationOutput): @@ -254,59 +266,6 @@ class Graph(BaseModel): default_factory=list, ) - @root_validator - def validate_nodes_and_edges(cls, values): - """Validates that all edges match nodes in the graph""" - nodes = cast(Optional[dict[str, BaseInvocation]], values.get("nodes")) - edges = cast(Optional[list[Edge]], values.get("edges")) - - if nodes is not None: - # Validate that all node ids are unique - node_ids = [n.id for n in nodes.values()] - duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2]) - if duplicate_node_ids: - raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") - - # Validate that all node ids match the keys in the nodes dict - for k, v in nodes.items(): - if k != v.id: - raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") - - if edges is not None and nodes is not None: - # Validate that all edges match nodes in the graph - node_ids = set([e.source.node_id for e in edges] + [e.destination.node_id for e in edges]) - missing_node_ids = [node_id for node_id in node_ids if node_id not in nodes] - if missing_node_ids: - raise NodeNotFoundError( - f"All edges must reference nodes in the graph, missing nodes: {missing_node_ids}" - ) - - # Validate that all edge fields match node fields in the graph - for edge in edges: - source_node = nodes.get(edge.source.node_id, None) - if source_node is None: - raise NodeFieldNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph") - - destination_node = nodes.get(edge.destination.node_id, None) - if destination_node is None: - raise NodeFieldNotFoundError( - f"Edge destination node {edge.destination.node_id} does not exist in the graph" - ) - - # output fields are not on the node object directly, they are on the output type - if edge.source.field not in source_node.get_output_type().__fields__: - raise NodeFieldNotFoundError( - f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" - ) - - # input fields are on the node - if edge.destination.field not in destination_node.__fields__: - raise NodeFieldNotFoundError( - f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}" - ) - - return values - def add_node(self, node: BaseInvocation) -> None: """Adds a node to a graph @@ -377,53 +336,108 @@ class Graph(BaseModel): except KeyError: pass - def is_valid(self) -> bool: - """Validates the graph.""" + def validate_self(self) -> None: + """ + Validates the graph. + + Raises an exception if the graph is invalid: + - `DuplicateNodeIdError` + - `NodeIdMismatchError` + - `InvalidSubGraphError` + - `NodeNotFoundError` + - `NodeFieldNotFoundError` + - `CyclicalGraphError` + - `InvalidEdgeError` + """ + + # Validate that all node ids are unique + node_ids = [n.id for n in self.nodes.values()] + duplicate_node_ids = set([node_id for node_id in node_ids if node_ids.count(node_id) >= 2]) + if duplicate_node_ids: + raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}") + + # Validate that all node ids match the keys in the nodes dict + for k, v in self.nodes.items(): + if k != v.id: + raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}") # Validate all subgraphs for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)): - if not gn.graph.is_valid(): - return False + try: + gn.graph.validate_self() + except Exception as e: + raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e - # Validate all edges reference nodes in the graph - node_ids = set([e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]) - if not all((self.has_node(node_id) for node_id in node_ids)): - return False + # Validate that all edges match nodes and fields in the graph + for edge in self.edges: + source_node = self.nodes.get(edge.source.node_id, None) + if source_node is None: + raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph") + + destination_node = self.nodes.get(edge.destination.node_id, None) + if destination_node is None: + raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") + + # output fields are not on the node object directly, they are on the output type + if edge.source.field not in source_node.get_output_type().__fields__: + raise NodeFieldNotFoundError( + f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" + ) + + # input fields are on the node + if edge.destination.field not in destination_node.__fields__: + raise NodeFieldNotFoundError( + f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}" + ) # Validate there are no cycles g = self.nx_graph_flat() if not nx.is_directed_acyclic_graph(g): - return False + raise CyclicalGraphError("Graph contains cycles") # Validate all edge connections are valid - if not all( - ( - are_connections_compatible( - self.get_node(e.source.node_id), - e.source.field, - self.get_node(e.destination.node_id), - e.destination.field, + for e in self.edges: + if not are_connections_compatible( + self.get_node(e.source.node_id), + e.source.field, + self.get_node(e.destination.node_id), + e.destination.field, + ): + raise InvalidEdgeError( + f"Invalid edge from {e.source.node_id}.{e.source.field} to {e.destination.node_id}.{e.destination.field}" ) - for e in self.edges - ) + + # Validate all iterators & collectors + # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available + for n in self.nodes.values(): + if isinstance(n, IterateInvocation) and not self._is_iterator_connection_valid(n.id): + raise InvalidEdgeError(f"Invalid iterator node {n.id}") + if isinstance(n, CollectInvocation) and not self._is_collector_connection_valid(n.id): + raise InvalidEdgeError(f"Invalid collector node {n.id}") + + return None + + def is_valid(self) -> bool: + """ + Checks if the graph is valid. + + Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error). + """ + try: + self.validate_self() + return True + except ( + DuplicateNodeIdError, + NodeIdMismatchError, + InvalidSubGraphError, + NodeNotFoundError, + NodeFieldNotFoundError, + CyclicalGraphError, + InvalidEdgeError, ): return False - - # Validate all iterators - # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available - if not all( - (self._is_iterator_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, IterateInvocation)) - ): - return False - - # Validate all collectors - # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available - if not all( - (self._is_collector_connection_valid(n.id) for n in self.nodes.values() if isinstance(n, CollectInvocation)) - ): - return False - - return True + except Exception as e: + raise UnknownGraphValidationError(f"Problem validating graph {e}") from e def _validate_edge(self, edge: Edge): """Validates that a new edge doesn't create a cycle in the graph""" @@ -804,6 +818,12 @@ class GraphExecutionState(BaseModel): default_factory=dict, ) + @validator("graph") + def graph_is_valid(cls, v: Graph): + """Validates that the graph is valid""" + v.validate_self() + return v + class Config: schema_extra = { "required": [ diff --git a/invokeai/app/services/session_queue/session_queue_common.py b/invokeai/app/services/session_queue/session_queue_common.py index 905e568fa1..a1eada6523 100644 --- a/invokeai/app/services/session_queue/session_queue_common.py +++ b/invokeai/app/services/session_queue/session_queue_common.py @@ -123,6 +123,11 @@ class Batch(BaseModel): raise NodeNotFoundError(f"Field {batch_data.field_name} not found in node {batch_data.node_path}") return values + @validator("graph") + def validate_graph(cls, v: Graph): + v.validate_self() + return v + class Config: schema_extra = { "required": [