# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import copy
import itertools
from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints

import networkx as nx
from pydantic import (
    BaseModel,
    GetJsonSchemaHandler,
    ValidationError,
    field_validator,
)
from pydantic.fields import Field
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import CoreSchema

# Importing * is bad karma but needed here for node detection
from invokeai.app.invocations import *  # noqa: F401 F403
from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import Input, InputField, OutputField, UIType
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import uuid_string

# in 3.10 this would be "from types import NoneType"
NoneType = type(None)


class EdgeConnection(BaseModel):
    node_id: str = Field(description="The id of the node for this edge connection")
    field: str = Field(description="The field for this connection")

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__)
            and getattr(other, "node_id", None) == self.node_id
            and getattr(other, "field", None) == self.field
        )

    def __hash__(self):
        return hash(f"{self.node_id}.{self.field}")


class Edge(BaseModel):
    source: EdgeConnection = Field(description="The connection for the edge's from node and field")
    destination: EdgeConnection = Field(description="The connection for the edge's to node and field")


def get_output_field(node: BaseInvocation, field: str) -> Any:
    node_type = type(node)
    node_outputs = get_type_hints(node_type.get_output_annotation())
    node_output_field = node_outputs.get(field) or None
    return node_output_field


def get_input_field(node: BaseInvocation, field: str) -> Any:
    node_type = type(node)
    node_inputs = get_type_hints(node_type)
    node_input_field = node_inputs.get(field) or None
    return node_input_field


def is_union_subtype(t1, t2):
    t1_args = get_args(t1)
    t2_args = get_args(t2)
    if not t1_args:
        # t1 is a single type
        return t1 in t2_args
    else:
        # t1 is a Union, check that all of its types are in t2_args
        return all(arg in t2_args for arg in t1_args)


def is_list_or_contains_list(t):
    t_args = get_args(t)

    # If the type is a List
    if get_origin(t) is list:
        return True

    # If the type is a Union
    elif t_args:
        # Check if any of the types in the Union is a List
        for arg in t_args:
            if get_origin(arg) is list:
                return True

    return False

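# Illustrative behavior of the two helpers above (a sketch; the types shown are arbitrary examples,
# not part of the runtime logic):
#
#   is_union_subtype(int, Union[int, str])                      -> True   (int is a member of the union)
#   is_union_subtype(Union[int, str], Union[int, str, float])   -> True   (all members are present)
#   is_union_subtype(Union[int, bytes], Union[int, str])        -> False  (bytes is not a member)
#
#   is_list_or_contains_list(list[int])                         -> True
#   is_list_or_contains_list(Union[list[int], None])            -> True
#   is_list_or_contains_list(int)                               -> False
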
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    if not from_type:
        return False
    if not to_type:
        return False

    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
    if from_type and to_type:
        # Ports are compatible
        if (
            from_type == to_type
            or from_type == Any
            or to_type == Any
            or Any in get_args(from_type)
            or Any in get_args(to_type)
        ):
            return True

        if from_type in get_args(to_type):
            return True

        if to_type in get_args(from_type):
            return True

        # allow int -> float, pydantic will cast for us
        if from_type is int and to_type is float:
            return True

        # allow int|float -> str, pydantic will cast for us
        if (from_type is int or from_type is float) and to_type is str:
            return True

        # if not issubclass(from_type, to_type):
        if not is_union_subtype(from_type, to_type):
            return False
    else:
        return False

    return True


def are_connections_compatible(
    from_node: BaseInvocation, from_field: str, to_node: BaseInvocation, to_field: str
) -> bool:
    """Determines if a connection between fields of two nodes is compatible."""

    # TODO: handle iterators and collectors
    from_node_field = get_output_field(from_node, from_field)
    to_node_field = get_input_field(to_node, to_field)

    return are_connection_types_compatible(from_node_field, to_node_field)


T = TypeVar("T")


def copydeep(obj: T) -> T:
    """Deep-copies an object. If it is a pydantic model, use the model's copy method."""
    if isinstance(obj, BaseModel):
        return obj.model_copy(deep=True)
    return copy.deepcopy(obj)


class NodeAlreadyInGraphError(ValueError):
    pass


class InvalidEdgeError(ValueError):
    pass


class NodeNotFoundError(ValueError):
    pass


class NodeAlreadyExecutedError(ValueError):
    pass


class DuplicateNodeIdError(ValueError):
    pass


class NodeFieldNotFoundError(ValueError):
    pass


class NodeIdMismatchError(ValueError):
    pass


class CyclicalGraphError(ValueError):
    pass


class UnknownGraphValidationError(ValueError):
    pass


class NodeInputError(ValueError):
    """Raised when a node fails preparation. This occurs when a node's inputs are being set from its incomers, but
    an input fails validation.

    Attributes:
        node: The node that failed preparation. Note: only successfully set fields will be accurate. Review the
            error to determine which field caused the failure.
    """

    def __init__(self, node: BaseInvocation, e: ValidationError):
        self.original_error = e
        self.node = node
        # When preparing a node, we set each input one-at-a-time. We may thus safely assume that the first error
        # represents the first input that failed.
        self.failed_input = loc_to_dot_sep(e.errors()[0]["loc"])
        super().__init__(f"Node {node.id} has invalid incoming input for {self.failed_input}")


def loc_to_dot_sep(loc: tuple[Union[str, int], ...]) -> str:
    """Helper to pretty-print pydantic error locations as dot-separated strings.
    Taken from https://docs.pydantic.dev/latest/errors/errors/#customize-error-messages
    """
    path = ""
    for i, x in enumerate(loc):
        if isinstance(x, str):
            if i > 0:
                path += "."
            path += x
        else:
            path += f"[{x}]"
    return path

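# Example of the formatting produced by loc_to_dot_sep (the location shown is an illustrative,
# hypothetical pydantic error location):
#
#   loc_to_dot_sep(("collection", 2, "image", "image_name"))  ->  "collection[2].image.image_name"
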
@invocation_output("iterate_output")
class IterateInvocationOutput(BaseInvocationOutput):
    """Used to connect iteration outputs. Will be expanded to a specific output."""

    item: Any = OutputField(
        description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem
    )
    index: int = OutputField(description="The index of the item", title="Index")
    total: int = OutputField(description="The total number of items", title="Total")


# TODO: Fill this out and move to invocations
@invocation("iterate", version="1.1.0")
class IterateInvocation(BaseInvocation):
    """Iterates over a list of items"""

    collection: list[Any] = InputField(
        description="The list of items to iterate over", default=[], ui_type=UIType._Collection
    )
    index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True)

    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
        """Produces the outputs as values"""
        return IterateInvocationOutput(item=self.collection[self.index], index=self.index, total=len(self.collection))


@invocation_output("collect_output")
class CollectInvocationOutput(BaseInvocationOutput):
    collection: list[Any] = OutputField(
        description="The collection of input items", title="Collection", ui_type=UIType._Collection
    )


@invocation("collect", version="1.0.0")
class CollectInvocation(BaseInvocation):
    """Collects values into a collection"""

    item: Optional[Any] = InputField(
        default=None,
        description="The item to collect (all inputs must be of the same type)",
        ui_type=UIType._CollectionItem,
        title="Collection Item",
        input=Input.Connection,
    )
    collection: list[Any] = InputField(
        description="The collection, will be provided on execution", default=[], ui_hidden=True
    )

    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
        """Invoke with provided services and return outputs."""
        return CollectInvocationOutput(collection=copy.copy(self.collection))

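# A sketch of how iterate/collect nodes are typically wired (the nodes other than IterateInvocation
# and CollectInvocation are hypothetical placeholders):
#
#   some_collection_source.collection  ->  IterateInvocation.collection
#   IterateInvocation.item             ->  some_worker_node.value
#   some_worker_node.result            ->  CollectInvocation.item
#   CollectInvocation.collection       ->  downstream consumer
#
# At preparation time the iterator is expanded into one execution node per collection item (see
# GraphExecutionState._create_execution_node below), and the collect node gathers the per-iteration
# results back into a single list.
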
class Graph(BaseModel):
    id: str = Field(description="The id of this graph", default_factory=uuid_string)
    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
    nodes: dict[str, BaseInvocation] = Field(description="The nodes in this graph", default_factory=dict)
    edges: list[Edge] = Field(
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )

    @field_validator("nodes", mode="plain")
    @classmethod
    def validate_nodes(cls, v: dict[str, Any]):
        """Validates the nodes in the graph by retrieving a union of all node types and validating each node."""

        # Invocations register themselves as their python modules are executed. The union of all invocations is
        # constructed at runtime. We use pydantic to validate `Graph.nodes` using that union.
        #
        # It's possible that when `graph.py` is executed, not all invocation-containing modules will have executed. If
        # we construct the invocation union as `graph.py` is executed, we may miss some invocations. Those missing
        # invocations will cause a graph to fail if they are used.
        #
        # We can get around this by validating the nodes in the graph using a "plain" validator, which overrides the
        # pydantic validation entirely. This allows us to validate the nodes using the union of invocations at
        # runtime.
        #
        # This same pattern is used in `GraphExecutionState`.

        nodes: dict[str, BaseInvocation] = {}
        typeadapter = BaseInvocation.get_typeadapter()
        for node_id, node in v.items():
            nodes[node_id] = typeadapter.validate_python(node)
        return nodes

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        # We use a "plain" validator to validate the nodes in the graph. Pydantic is unable to create a JSON Schema for
        # fields that use "plain" validators, so we have to hack around this. Also, we need to add all invocations to
        # the generated schema as options for the `nodes` field.
        #
        # The workaround is to create a new BaseModel that has the same fields as `Graph` but without the validator and
        # with the invocation union as the type for the `nodes` field. Pydantic then generates the JSON Schema as
        # expected.
        #
        # You might be tempted to do something like this:
        #
        # ```py
        # cloned_model = create_model(cls.__name__, __base__=cls, nodes=...)
        # delattr(cloned_model, "validate_nodes")
        # cloned_model.model_rebuild(force=True)
        # json_schema = handler(cloned_model.__pydantic_core_schema__)
        # ```
        #
        # Unfortunately, this does not work. Calling `handler` here results in infinite recursion as pydantic attempts
        # to build the JSON Schema for the cloned model. Instead, we have to manually clone the model.
        #
        # This same pattern is used in `GraphExecutionState`.

        class Graph(BaseModel):
            id: Optional[str] = Field(default=None, description="The id of this graph")
            nodes: dict[
                str, Annotated[Union[tuple(BaseInvocation._invocation_classes)], Field(discriminator="type")]
            ] = Field(description="The nodes in this graph")
            edges: list[Edge] = Field(description="The connections between nodes and their fields in this graph")

        json_schema = handler(Graph.__pydantic_core_schema__)
        json_schema = handler.resolve_ref_schema(json_schema)
        return json_schema

    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to a graph

        :raises NodeAlreadyInGraphError: the node is already present in the graph.
        """

        if node.id in self.nodes:
            raise NodeAlreadyInGraphError()

        self.nodes[node.id] = node

    def delete_node(self, node_id: str) -> None:
        """Deletes a node from a graph"""

        try:
            # Delete edges for this node
            input_edges = self._get_input_edges(node_id)
            output_edges = self._get_output_edges(node_id)

            for edge in input_edges:
                self.delete_edge(edge)

            for edge in output_edges:
                self.delete_edge(edge)

            del self.nodes[node_id]
        except NodeNotFoundError:
            pass  # Ignore, node doesn't exist (should this throw?)

    def add_edge(self, edge: Edge) -> None:
        """Adds an edge to a graph

        :raises InvalidEdgeError: the provided edge is invalid.
        """

        self._validate_edge(edge)
        if edge not in self.edges:
            self.edges.append(edge)
        else:
            raise InvalidEdgeError()

    def delete_edge(self, edge: Edge) -> None:
        """Deletes an edge from a graph"""

        try:
            self.edges.remove(edge)
        except ValueError:
            # list.remove raises ValueError when the edge is not present; ignore missing edges
            pass

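    # Minimal usage sketch for building a graph (the invocation classes and field names below are
    # hypothetical placeholders; any registered BaseInvocation subclasses with compatible fields would do):
    #
    #   g = Graph()
    #   g.add_node(SomeSourceInvocation(id="a"))
    #   g.add_node(SomeSinkInvocation(id="b"))
    #   g.add_edge(
    #       Edge(
    #           source=EdgeConnection(node_id="a", field="value"),
    #           destination=EdgeConnection(node_id="b", field="value"),
    #       )
    #   )
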
    def validate_self(self) -> None:
        """
        Validates the graph.

        Raises an exception if the graph is invalid:
        - `DuplicateNodeIdError`
        - `NodeIdMismatchError`
        - `NodeNotFoundError`
        - `NodeFieldNotFoundError`
        - `CyclicalGraphError`
        - `InvalidEdgeError`
        """

        # Validate that all node ids are unique
        node_ids = [n.id for n in self.nodes.values()]
        duplicate_node_ids = {node_id for node_id in node_ids if node_ids.count(node_id) >= 2}
        if duplicate_node_ids:
            raise DuplicateNodeIdError(f"Node ids must be unique, found duplicates {duplicate_node_ids}")

        # Validate that all node ids match the keys in the nodes dict
        for k, v in self.nodes.items():
            if k != v.id:
                raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")

        # Validate that all edges match nodes and fields in the graph
        for edge in self.edges:
            source_node = self.nodes.get(edge.source.node_id, None)
            if source_node is None:
                raise NodeNotFoundError(f"Edge source node {edge.source.node_id} does not exist in the graph")

            destination_node = self.nodes.get(edge.destination.node_id, None)
            if destination_node is None:
                raise NodeNotFoundError(
                    f"Edge destination node {edge.destination.node_id} does not exist in the graph"
                )

            # output fields are not on the node object directly, they are on the output type
            if edge.source.field not in source_node.get_output_annotation().model_fields:
                raise NodeFieldNotFoundError(
                    f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}"
                )

            # input fields are on the node
            if edge.destination.field not in destination_node.model_fields:
                raise NodeFieldNotFoundError(
                    f"Edge destination field {edge.destination.field} does not exist in node {edge.destination.node_id}"
                )

        # Validate there are no cycles
        g = self.nx_graph_flat()
        if not nx.is_directed_acyclic_graph(g):
            raise CyclicalGraphError("Graph contains cycles")

        # Validate all edge connections are valid
        for edge in self.edges:
            if not are_connections_compatible(
                self.get_node(edge.source.node_id),
                edge.source.field,
                self.get_node(edge.destination.node_id),
                edge.destination.field,
            ):
                raise InvalidEdgeError(
                    f"Invalid edge from {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

        # Validate all iterators & collectors
        # TODO: may need to validate all iterators & collectors in subgraphs so edge connections in parent graphs will be available
        for node in self.nodes.values():
            if isinstance(node, IterateInvocation) and not self._is_iterator_connection_valid(node.id):
                raise InvalidEdgeError(f"Invalid iterator node {node.id}")
            if isinstance(node, CollectInvocation) and not self._is_collector_connection_valid(node.id):
                raise InvalidEdgeError(f"Invalid collector node {node.id}")

        return None

    def is_valid(self) -> bool:
        """
        Checks if the graph is valid.

        Raises `UnknownGraphValidationError` if there is a problem validating the graph (not a validation error).
        """
        try:
            self.validate_self()
            return True
        except (
            DuplicateNodeIdError,
            NodeIdMismatchError,
            NodeNotFoundError,
            NodeFieldNotFoundError,
            CyclicalGraphError,
            InvalidEdgeError,
        ):
            return False
        except Exception as e:
            raise UnknownGraphValidationError(f"Problem validating graph {e}") from e

    def _is_destination_field_Any(self, edge: Edge) -> bool:
        """Checks if the destination field for an edge is of type typing.Any"""
        return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == Any

    def _is_destination_field_list_of_Any(self, edge: Edge) -> bool:
        """Checks if the destination field for an edge is of type list[typing.Any]"""
        return get_input_field(self.get_node(edge.destination.node_id), edge.destination.field) == list[Any]

    def _validate_edge(self, edge: Edge):
        """Validates that a new edge doesn't create a cycle in the graph"""

        # Validate that the nodes exist
        try:
            from_node = self.get_node(edge.source.node_id)
            to_node = self.get_node(edge.destination.node_id)
        except NodeNotFoundError:
            raise InvalidEdgeError(
                f"One or both nodes don't exist: {edge.source.node_id} -> {edge.destination.node_id}"
            )

        # Validate that an edge to this node+field doesn't already exist
        input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
        if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
            raise InvalidEdgeError(
                f"Edge to node {edge.destination.node_id} field {edge.destination.field} already exists"
            )

        # Validate that no cycles would be created
        g = self.nx_graph_flat()
        g.add_edge(edge.source.node_id, edge.destination.node_id)
        if not nx.is_directed_acyclic_graph(g):
            raise InvalidEdgeError(
                f"Edge creates a cycle in the graph: {edge.source.node_id} -> {edge.destination.node_id}"
            )

        # Validate that the field types are compatible
        if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
            raise InvalidEdgeError(
                f"Fields are incompatible: cannot connect {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
            )

        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
        if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
            if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
                raise InvalidEdgeError(
                    f"Iterator input type does not match iterator output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

        # Validate if iterator input type matches output type (if this edge results in both being set)
        if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
            if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
                raise InvalidEdgeError(
                    f"Iterator output type does not match iterator input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

        # Validate if collector input type matches output type (if this edge results in both being set)
        if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
            if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
                raise InvalidEdgeError(
                    f"Collector output type does not match collector input type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

        # Validate that we are not connecting collector to iterator (currently unsupported)
        if isinstance(from_node, CollectInvocation) and isinstance(to_node, IterateInvocation):
            raise InvalidEdgeError(
                f"Cannot connect collector to iterator: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
            )

        # Validate if collector output type matches input type (if this edge results in both being set)
        # - skip if the destination field is not Any or list[Any]
        if (
            isinstance(from_node, CollectInvocation)
            and edge.source.field == "collection"
            and not self._is_destination_field_list_of_Any(edge)
            and not self._is_destination_field_Any(edge)
        ):
            if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
                raise InvalidEdgeError(
                    f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
                )

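    # For example (a sketch; ids and fields are arbitrary): given an existing edge a.value -> b.value,
    # calling add_edge() with b.value -> a.value raises InvalidEdgeError here, because the new edge
    # would create a cycle in the flattened graph.
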
    def has_node(self, node_id: str) -> bool:
        """Determines whether or not a node exists in the graph."""
        try:
            _ = self.get_node(node_id)
            return True
        except NodeNotFoundError:
            return False

    def get_node(self, node_id: str) -> BaseInvocation:
        """Gets a node from the graph."""
        try:
            return self.nodes[node_id]
        except KeyError as e:
            raise NodeNotFoundError(f"Node {node_id} not found in graph") from e

    def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
        """Updates a node in the graph."""
        node = self.nodes[node_id]

        # Ensure the node type matches the new node
        if type(node) is not type(new_node):
            raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")

        # Ensure the new id is either the same or is not in the graph
        if new_node.id != node.id and self.has_node(new_node.id):
            raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")

        # Set the new node in the graph
        self.nodes[new_node.id] = new_node
        if new_node.id != node.id:
            input_edges = self._get_input_edges(node_id)
            output_edges = self._get_output_edges(node_id)

            # Delete node and all edges
            self.delete_node(node_id)

            # Create new edges for each input and output
            for edge in input_edges:
                self.add_edge(
                    Edge(
                        source=edge.source,
                        destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
                    )
                )

            for edge in output_edges:
                self.add_edge(
                    Edge(
                        source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
                        destination=edge.destination,
                    )
                )

    def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
        """Gets all input edges for a node. If field is provided, only edges to that field are returned."""
        edges = [e for e in self.edges if e.destination.node_id == node_id]

        if field is None:
            return edges

        filtered_edges = [e for e in edges if e.destination.field == field]

        return filtered_edges

    def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
        """Gets all output edges for a node. If field is provided, only edges from that field are returned."""
        edges = [e for e in self.edges if e.source.node_id == node_id]

        if field is None:
            return edges

        filtered_edges = [e for e in edges if e.source.field == field]

        return filtered_edges

    def _is_iterator_connection_valid(
        self,
        node_id: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
        outputs = [e.destination for e in self._get_output_edges(node_id, "item")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Only one input is allowed for iterators
        if len(inputs) > 1:
            return False

        # Get input and output fields (the fields linked to the iterator's input/output)
        input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

        # Input type must be a list
        if get_origin(input_field) != list:
            return False

        # Validate that all outputs match the input type
        input_field_item_type = get_args(input_field)[0]
        if not all((are_connection_types_compatible(input_field_item_type, f) for f in output_fields)):
            return False

        return True

    def _is_collector_connection_valid(
        self,
        node_id: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        inputs = [e.source for e in self._get_input_edges(node_id, "item")]
        outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Get input and output fields (the fields linked to the collector's input/output)
        input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

        # Validate that all inputs are derived from or match a single type
        input_field_types = {
            t
            for input_field in input_fields
            for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
            if t != NoneType
        }  # Get unique types
        type_tree = nx.DiGraph()
        type_tree.add_nodes_from(input_field_types)
        type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
        type_degrees = type_tree.in_degree(type_tree.nodes)
        if sum((t[1] == 0 for t in type_degrees)) != 1:  # type: ignore
            return False  # There is more than one root type

        # Get the input root type
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        if not all(is_list_or_contains_list(f) for f in output_fields):
            return False

        # Verify that all outputs match the input type (are a base class or the same class)
        if not all(
            is_union_subtype(input_root_type, get_args(f)[0]) or issubclass(input_root_type, get_args(f)[0])
            for f in output_fields
        ):
            return False

        return True

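    # Worked example for the collector root-type check above (a sketch): if a collector's inputs are
    # typed int and bool, the type tree gets the edge int -> bool (issubclass(bool, int) is True),
    # leaving a single root type (int), so the check passes. If the inputs are int and str, neither
    # is a subclass of the other, there are two roots, and the check fails.
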
    def nx_graph(self) -> nx.DiGraph:
        """Returns a NetworkX DiGraph representing the layout of this graph"""
        # TODO: Cache this?
        g = nx.DiGraph()
        g.add_nodes_from(list(self.nodes.keys()))
        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
        return g

    def nx_graph_with_data(self) -> nx.DiGraph:
        """Returns a NetworkX DiGraph representing the data and layout of this graph"""
        g = nx.DiGraph()
        g.add_nodes_from(list(self.nodes.items()))
        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
        return g

    def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
        """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
        g = nx_graph or nx.DiGraph()

        # Add all nodes from this graph except graph/iteration nodes
        g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])

        # TODO: figure out if iteration nodes need to be expanded
        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
        g.add_edges_from([(e[0], e[1]) for e in unique_edges])
        return g


class GraphExecutionState(BaseModel):
    """Tracks the state of a graph execution"""

    id: str = Field(description="The id of the execution state", default_factory=uuid_string)

    # TODO: Store a reference to the graph instead of the actual graph?
    graph: Graph = Field(description="The graph being executed")

    # The graph of materialized nodes
    execution_graph: Graph = Field(
        description="The expanded graph of activated and executed nodes",
        default_factory=Graph,
    )

    # Nodes that have been executed
    executed: set[str] = Field(description="The set of node ids that have been executed", default_factory=set)
    executed_history: list[str] = Field(
        description="The list of node ids that have been executed, in order of execution",
        default_factory=list,
    )

    # The results of executed nodes
    results: dict[str, BaseInvocationOutput] = Field(
        description="The results of node executions", default_factory=dict
    )

    # Errors raised when executing nodes
    errors: dict[str, str] = Field(description="Errors raised when executing nodes", default_factory=dict)

    # Map of prepared/executed nodes to their original nodes
    prepared_source_mapping: dict[str, str] = Field(
        description="The map of prepared nodes to original graph nodes",
        default_factory=dict,
    )

    # Map of original nodes to prepared nodes
    source_prepared_mapping: dict[str, set[str]] = Field(
        description="The map of original graph nodes to prepared nodes",
        default_factory=dict,
    )

    @field_validator("results", mode="plain")
    @classmethod
    def validate_results(cls, v: dict[str, BaseInvocationOutput]):
        """Validates the results in the GES by retrieving a union of all output types and validating each result."""

        # See the comment in `Graph.validate_nodes` for an explanation of this logic.
        results: dict[str, BaseInvocationOutput] = {}
        typeadapter = BaseInvocationOutput.get_typeadapter()
        for result_id, result in v.items():
            results[result_id] = typeadapter.validate_python(result)
        return results

    @field_validator("graph")
    def graph_is_valid(cls, v: Graph):
        """Validates that the graph is valid"""
        v.validate_self()
        return v

    @classmethod
    def __get_pydantic_json_schema__(cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler) -> JsonSchemaValue:
        # See the comment in `Graph.__get_pydantic_json_schema__` for an explanation of this logic.

        class GraphExecutionState(BaseModel):
            """Tracks the state of a graph execution"""

            id: str = Field(description="The id of the execution state")
            graph: Graph = Field(description="The graph being executed")
            execution_graph: Graph = Field(description="The expanded graph of activated and executed nodes")
            executed: set[str] = Field(description="The set of node ids that have been executed")
            executed_history: list[str] = Field(
                description="The list of node ids that have been executed, in order of execution"
            )
            results: dict[
                str, Annotated[Union[tuple(BaseInvocationOutput._output_classes)], Field(discriminator="type")]
            ] = Field(description="The results of node executions")
            errors: dict[str, str] = Field(description="Errors raised when executing nodes")
            prepared_source_mapping: dict[str, str] = Field(
                description="The map of prepared nodes to original graph nodes"
            )
            source_prepared_mapping: dict[str, set[str]] = Field(
                description="The map of original graph nodes to prepared nodes"
            )

        json_schema = handler(GraphExecutionState.__pydantic_core_schema__)
        json_schema = handler.resolve_ref_schema(json_schema)
        return json_schema

    def next(self) -> Optional[BaseInvocation]:
        """Gets the next node ready to execute."""

        # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
        #       possibly with a timeout?

        # If there are no prepared nodes, prepare some nodes
        next_node = self._get_next_node()
        if next_node is None:
            prepared_id = self._prepare()

            # Prepare as many nodes as we can
            while prepared_id is not None:
                prepared_id = self._prepare()
            next_node = self._get_next_node()

        # Get values from edges
        if next_node is not None:
            try:
                self._prepare_inputs(next_node)
            except ValidationError as e:
                raise NodeInputError(next_node, e)

        # If next is still none, there's no next node, return None
        return next_node

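    # Typical driver loop (a sketch; `context` stands for an InvocationContext supplied by the caller,
    # and error handling is omitted):
    #
    #   session = GraphExecutionState(graph=graph)
    #   while not session.is_complete():
    #       node = session.next()
    #       if node is None:
    #           break
    #       output = node.invoke(context)
    #       session.complete(node.id, output)
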
    def complete(self, node_id: str, output: BaseInvocationOutput) -> None:
        """Marks a node as complete"""

        if node_id not in self.execution_graph.nodes:
            return  # TODO: log error?

        # Mark node as executed
        self.executed.add(node_id)
        self.results[node_id] = output

        # Check if source node is complete (all prepared nodes are complete)
        source_node = self.prepared_source_mapping[node_id]
        prepared_nodes = self.source_prepared_mapping[source_node]

        if all(n in self.executed for n in prepared_nodes):
            self.executed.add(source_node)
            self.executed_history.append(source_node)

    def set_node_error(self, node_id: str, error: str):
        """Marks a node as errored"""
        self.errors[node_id] = error

    def is_complete(self) -> bool:
        """Returns true if the graph is complete"""
        node_ids = set(self.graph.nx_graph_flat().nodes)
        return self.has_error() or all((k in self.executed for k in node_ids))

    def has_error(self) -> bool:
        """Returns true if the graph has any errors"""
        return len(self.errors) > 0

    def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
        """Prepares an iteration node and connects all edges, returning the new node id"""

        node = self.graph.get_node(node_id)

        self_iteration_count = -1

        # If this is an iterator node, we must create a copy for each iteration
        if isinstance(node, IterateInvocation):
            # Get input collection edge (should error if there are no inputs)
            input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
            input_collection_prepared_node_id = next(
                n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
            )
            input_collection_prepared_node_output = self.results[input_collection_prepared_node_id]
            input_collection = getattr(input_collection_prepared_node_output, input_collection_edge.source.field)
            self_iteration_count = len(input_collection)

        new_nodes: list[str] = []
        if self_iteration_count == 0:
            # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
            return new_nodes

        # Get all input edges
        input_edges = self.graph._get_input_edges(node_id)

        # Create new edges for this iteration
        # For collect nodes, this may contain multiple inputs to the same field
        new_edges: list[Edge] = []
        for edge in input_edges:
            for input_node_id in (n[1] for n in iteration_node_map if n[0] == edge.source.node_id):
                new_edge = Edge(
                    source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
                    destination=EdgeConnection(node_id="", field=edge.destination.field),
                )
                new_edges.append(new_edge)

        # Create a new node (or one for each iteration of this iterator)
        for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]:
            # Create a new node
            new_node = copy.deepcopy(node)

            # Create the node id (use a random uuid)
            new_node.id = uuid_string()

            # Set the iteration index for iteration invocations
            if isinstance(new_node, IterateInvocation):
                new_node.index = i

            # Add to execution graph
            self.execution_graph.add_node(new_node)
            self.prepared_source_mapping[new_node.id] = node_id
            if node_id not in self.source_prepared_mapping:
                self.source_prepared_mapping[node_id] = set()
            self.source_prepared_mapping[node_id].add(new_node.id)

            # Add new edges to execution graph
            for edge in new_edges:
                new_edge = Edge(
                    source=edge.source,
                    destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
                )
                self.execution_graph.add_edge(new_edge)

            new_nodes.append(new_node.id)

        return new_nodes

    def _iterator_graph(self) -> nx.DiGraph:
        """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
        g = self.graph.nx_graph_flat()
        collectors = (n for n in self.graph.nodes if isinstance(self.graph.get_node(n), CollectInvocation))
        for c in collectors:
            g.remove_edges_from(list(g.in_edges(c)))

        return g

    def _get_node_iterators(self, node_id: str) -> list[str]:
        """Gets iterators for a node"""
        g = self._iterator_graph()

        iterators = [n for n in nx.ancestors(g, node_id) if isinstance(self.graph.get_node(n), IterateInvocation)]
        return iterators

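    # For example (a sketch): if an IterateInvocation's "collection" input resolves to a list of three
    # items at preparation time, _create_execution_node() produces three prepared copies of the iterator
    # (indexes 0, 1 and 2), and each downstream node is likewise prepared once per iteration until a
    # CollectInvocation merges the branches back into a single collection.
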
    def _prepare(self) -> Optional[str]:
        # Get flattened source graph
        g = self.graph.nx_graph_flat()

        # Find next node that:
        # - was not already prepared
        # - is not an iterate node whose inputs have not been executed
        # - does not have an unexecuted iterate ancestor
        sorted_nodes = nx.topological_sort(g)
        next_node_id = next(
            (
                n
                for n in sorted_nodes
                # exclude nodes that have already been prepared
                if n not in self.source_prepared_mapping
                # exclude iterate nodes whose inputs have not been executed
                and not (
                    isinstance(self.graph.get_node(n), IterateInvocation)  # `n` is an iterate node...
                    and not all((e[0] in self.executed for e in g.in_edges(n)))  # ...that has unexecuted inputs
                )
                # exclude nodes who have unexecuted iterate ancestors
                and not any(
                    (
                        isinstance(self.graph.get_node(a), IterateInvocation)  # `a` is an iterate ancestor of `n`...
                        and a not in self.executed  # ...that is not executed
                        for a in nx.ancestors(g, n)  # for all ancestors `a` of node `n`
                    )
                )
            ),
            None,
        )

        if next_node_id is None:
            return None

        # Get all parents of the next node
        next_node_parents = [e[0] for e in g.in_edges(next_node_id)]

        # Create execution nodes
        next_node = self.graph.get_node(next_node_id)
        new_node_ids = []
        if isinstance(next_node, CollectInvocation):
            # Collapse all iterator input mappings and create a single execution node for the collect invocation
            all_iteration_mappings = list(
                itertools.chain(*(((s, p) for p in self.source_prepared_mapping[s]) for s in next_node_parents))
            )
            # all_iteration_mappings = list(set(itertools.chain(*prepared_parent_mappings)))
            create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
            if create_results is not None:
                new_node_ids.extend(create_results)
        else:  # Iterators or normal nodes
            # Get all iterator combinations for this node
            # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
            iterator_nodes = self._get_node_iterators(next_node_id)
            iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
            iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))

            # Select the correct prepared parents for each iteration
            # For every iterator, the parent must either not be a child of that iterator, or must match the
            # prepared iteration for that iterator
            # TODO: Handle a node mapping to none
            eg = self.execution_graph.nx_graph_flat()
            prepared_parent_mappings = [
                [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
                for it in iterator_node_prepared_combinations
            ]  # type: ignore

            # Create execution node for each iteration
            for iteration_mappings in prepared_parent_mappings:
                create_results = self._create_execution_node(next_node_id, iteration_mappings)  # type: ignore
                if create_results is not None:
                    new_node_ids.extend(create_results)

        return next(iter(new_node_ids), None)

    def _get_iteration_node(
        self,
        source_node_id: str,
        graph: nx.DiGraph,
        execution_graph: nx.DiGraph,
        prepared_iterator_nodes: list[str],
    ) -> Optional[str]:
        """Gets the prepared version of the specified source node that matches every iteration specified"""
        prepared_nodes = self.source_prepared_mapping[source_node_id]
        if len(prepared_nodes) == 1:
            return next(iter(prepared_nodes))

        # Check if the requested node is an iterator
        prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
        if prepared_iterator is not None:
            return prepared_iterator

        # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
        iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
        parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]

        return next(
            (n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
            None,
        )

    def _get_next_node(self) -> Optional[BaseInvocation]:
        """Gets the deepest node that is ready to be executed"""
        g = self.execution_graph.nx_graph()

        # Perform a topological sort using depth-first search
        topo_order = list(nx.dfs_postorder_nodes(g))

        # Get all IterateInvocation nodes
        iterate_nodes = [n for n in topo_order if isinstance(self.execution_graph.nodes[n], IterateInvocation)]

        # Sort the IterateInvocation nodes based on their index attribute
        iterate_nodes.sort(key=lambda x: self.execution_graph.nodes[x].index)

        # Prioritize IterateInvocation nodes and their children
        for iterate_node in iterate_nodes:
            if iterate_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(iterate_node))):
                return self.execution_graph.nodes[iterate_node]

            # Check the children of the IterateInvocation node
            for child_node in nx.dfs_postorder_nodes(g, iterate_node):
                if child_node not in self.executed and all((e[0] in self.executed for e in g.in_edges(child_node))):
                    return self.execution_graph.nodes[child_node]

        # If no IterateInvocation node or its children are ready, return the first ready node in the topological order
        for node in topo_order:
            if node not in self.executed and all((e[0] in self.executed for e in g.in_edges(node))):
                return self.execution_graph.nodes[node]

        # If no node is found, return None
        return None

    def _prepare_inputs(self, node: BaseInvocation):
        input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
        # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
        # will see the mutation.
        if isinstance(node, CollectInvocation):
            output_collection = [
                copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
                for edge in input_edges
                if edge.destination.field == "item"
            ]
            node.collection = output_collection
        else:
            for edge in input_edges:
                setattr(
                    node,
                    edge.destination.field,
                    copydeep(getattr(self.results[edge.source.node_id], edge.source.field)),
                )

    # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
    def _is_edge_valid(self, edge: Edge) -> bool:
        try:
            self.graph._validate_edge(edge)
        except InvalidEdgeError:
            return False

        # Invalid if destination has already been prepared or executed
        if edge.destination.node_id in self.source_prepared_mapping:
            return False

        # Otherwise, the edge is valid
        return True

    def _is_node_updatable(self, node_id: str) -> bool:
        # The node is updatable as long as it hasn't been prepared or executed
        return node_id not in self.source_prepared_mapping

    def add_node(self, node: BaseInvocation) -> None:
        self.graph.add_node(node)

    def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
        if not self._is_node_updatable(node_id):
            raise NodeAlreadyExecutedError(
                f"Node {node_id} has already been prepared or executed and cannot be updated"
            )
        self.graph.update_node(node_id, new_node)

    def delete_node(self, node_id: str) -> None:
        if not self._is_node_updatable(node_id):
            raise NodeAlreadyExecutedError(
                f"Node {node_id} has already been prepared or executed and cannot be deleted"
            )
        self.graph.delete_node(node_id)

    def add_edge(self, edge: Edge) -> None:
        if not self._is_node_updatable(edge.destination.node_id):
            raise NodeAlreadyExecutedError(
                f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
            )
        self.graph.add_edge(edge)

    def delete_edge(self, edge: Edge) -> None:
        if not self._is_node_updatable(edge.destination.node_id):
            raise NodeAlreadyExecutedError(
                f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
            )
        self.graph.delete_edge(edge)
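
# Mutation guard example (a sketch; ids are arbitrary): once a destination node has been prepared or
# executed, GraphExecutionState.add_edge() and delete_edge() raise NodeAlreadyExecutedError instead of
# mutating a graph that is already in flight; update_node() and delete_node() behave the same way for
# their target node.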