# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)

import copy
import itertools
import traceback
import uuid
from types import NoneType
from typing import (
    Annotated,
    Any,
    Literal,
    Optional,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

import networkx as nx
from pydantic import BaseModel, validator
from pydantic.fields import Field

from ..invocations import *
from ..invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
)
from .invocation_services import InvocationServices


# One end of an edge: a node id (or dotted node path) plus a field name on that node.
# Equality and hashing are both defined on the (node_id, field) pair.
class EdgeConnection(BaseModel):
    node_id: str = Field(description="The id of the node for this edge connection")
    field: str = Field(description="The field for this connection")

    def __eq__(self, other):
        return (
            isinstance(other, self.__class__)
            and getattr(other, "node_id", None) == self.node_id
            and getattr(other, "field", None) == self.field
        )

    def __hash__(self):
        return hash(f"{self.node_id}.{self.field}")


# A directed connection from one node's output field to another node's input field.
class Edge(BaseModel):
    source: EdgeConnection = Field(description="The connection for the edge's from node and field")
    destination: EdgeConnection = Field(description="The connection for the edge's to node and field")


def get_output_field(node: BaseInvocation, field: str) -> Any:
    """Returns the type hint of `field` on the node's output type, or None if there is none."""
    node_type = type(node)
    node_outputs = get_type_hints(node_type.get_output_type())
    node_output_field = node_outputs.get(field) or None
    return node_output_field


def get_input_field(node: BaseInvocation, field: str) -> Any:
    """Returns the type hint of `field` on the node's own type, or None if there is none."""
    node_type = type(node)
    node_inputs = get_type_hints(node_type)
    node_input_field = node_inputs.get(field) or None
    return node_input_field


def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    """Determines whether a value of `from_type` may be connected to an input of `to_type`.

    Returns False if either type is missing (falsy).
    """
    if not from_type:
        return False
    if not to_type:
        return False

    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
    # Ports are compatible
    if (
        from_type == to_type
        or from_type == Any
        or to_type == Any
        or Any in get_args(from_type)
        or Any in get_args(to_type)
    ):
        return True

    if from_type in get_args(to_type):
        return True

    if to_type in get_args(from_type):
        return True

    try:
        return issubclass(from_type, to_type)
    except TypeError:
        # issubclass() rejects arguments that are not plain classes (e.g. Optional[...]
        # or parameterized generics); such a pair is not compatible rather than an error.
        return False


def are_connections_compatible(
    from_node: BaseInvocation, from_field: str, to_node: BaseInvocation, to_field: str
) -> bool:
    """Determines if a connection between fields of two nodes is compatible."""

    # TODO: handle iterators and collectors
    from_node_field = get_output_field(from_node, from_field)
    to_node_field = get_input_field(to_node, to_field)

    return are_connection_types_compatible(from_node_field, to_node_field)


class NodeAlreadyInGraphError(Exception):
    pass


class InvalidEdgeError(Exception):
    pass


class NodeNotFoundError(Exception):
    pass


class NodeAlreadyExecutedError(Exception):
    pass


# TODO: Create and use an Empty output?
class GraphInvocationOutput(BaseInvocationOutput):
    type: Literal["graph_output"] = "graph_output"

    class Config:
        schema_extra = {
            'required': [
                'type',
                'image',
            ]
        }


# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
    type: Literal["graph"] = "graph"

    # TODO: figure out how to create a default here
    graph: "Graph" = Field(description="The graph to run", default=None)

    def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
        """Invoke with provided services and return outputs."""
        return GraphInvocationOutput()


class IterateInvocationOutput(BaseInvocationOutput):
    """Used to connect iteration outputs. Will be expanded to a specific output."""

    type: Literal["iterate_output"] = "iterate_output"

    item: Any = Field(description="The item being iterated over")

    class Config:
        schema_extra = {
            'required': [
                'type',
                'item',
            ]
        }


# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
    type: Literal["iterate"] = "iterate"

    collection: list[Any] = Field(
        description="The list of items to iterate over", default_factory=list
    )
    index: int = Field(
        description="The index, will be provided on executed iterators", default=0
    )

    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
        """Produces the outputs as values"""
        return IterateInvocationOutput(item=self.collection[self.index])


class CollectInvocationOutput(BaseInvocationOutput):
    type: Literal["collect_output"] = "collect_output"

    collection: list[Any] = Field(description="The collection of input items")

    class Config:
        schema_extra = {
            'required': [
                'type',
                'collection',
            ]
        }


class CollectInvocation(BaseInvocation):
    """Collects values into a collection"""

    type: Literal["collect"] = "collect"

    item: Any = Field(
        description="The item to collect (all inputs must be of the same type)",
        default=None,
    )
    collection: list[Any] = Field(
        description="The collection, will be provided on execution",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
        """Invoke with provided services and return outputs."""
        return CollectInvocationOutput(collection=copy.copy(self.collection))


InvocationsUnion = Union[BaseInvocation.get_invocations()]  # type: ignore
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()]  # type: ignore


# A graph of invocation nodes keyed by node id, plus the edges connecting their fields.
# Nodes inside nested GraphInvocation nodes are addressed with dotted node paths
# ("<graph node id>.<node id>").
class Graph(BaseModel):
    id: str = Field(description="The id of this graph", default_factory=uuid.uuid4)
    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
    nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
        description="The nodes in this graph", default_factory=dict
    )
    edges: list[Edge] = Field(
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )

    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to a graph

        :raises NodeAlreadyInGraphError: the node is already present in the graph.
        """
        if node.id in self.nodes:
            raise NodeAlreadyInGraphError()

        self.nodes[node.id] = node

    def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
        """Returns the graph and node id for a node path.

        :raises NodeNotFoundError: the path does not resolve to a node.
        """
        # Materialized graphs may have nodes at the top level
        if node_path in self.nodes:
            return (self, node_path)

        node_id = (
            node_path if "." not in node_path else node_path[: node_path.index(".")]
        )
        if node_id not in self.nodes:
            raise NodeNotFoundError(f"Node {node_path} not found in graph")

        node = self.nodes[node_id]

        if not isinstance(node, GraphInvocation):
            # There's more node path left but this isn't a graph - failure
            raise NodeNotFoundError("Node path terminated early at a non-graph node")

        return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])

    def delete_node(self, node_path: str) -> None:
        """Deletes a node from a graph, along with every edge into or out of it.

        A path that does not resolve to a node is silently ignored.
        """

        try:
            graph, node_id = self._get_graph_and_node(node_path)

            # Delete edges for this node
            input_edges = self._get_input_edges_and_graphs(node_path)
            output_edges = self._get_output_edges_and_graphs(node_path)

            for edge_graph, _, edge in input_edges:
                edge_graph.delete_edge(edge)

            for edge_graph, _, edge in output_edges:
                edge_graph.delete_edge(edge)

            del graph.nodes[node_id]

        except NodeNotFoundError:
            pass  # Ignore, node doesn't exist (should this throw?)

    def add_edge(self, edge: Edge) -> None:
        """Adds an edge to a graph

        :raises InvalidEdgeError: the provided edge is invalid.
        """

        if self._is_edge_valid(edge) and edge not in self.edges:
            self.edges.append(edge)
        else:
            raise InvalidEdgeError()

    def delete_edge(self, edge: Edge) -> None:
        """Deletes an edge from a graph. An edge that is not present is ignored."""

        try:
            self.edges.remove(edge)
        except ValueError:
            # list.remove() signals a missing element with ValueError (not KeyError)
            pass

    def is_valid(self) -> bool:
        """Validates the graph."""

        # Validate all subgraphs
        for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
            if not gn.graph.is_valid():
                return False

        # Validate all edges reference nodes in the graph
        node_ids = set(
            [e.source.node_id for e in self.edges]
            + [e.destination.node_id for e in self.edges]
        )
        if not all((self.has_node(node_id) for node_id in node_ids)):
            return False

        # Validate there are no cycles
        g = self.nx_graph_flat()
        if not nx.is_directed_acyclic_graph(g):
            return False

        # Validate all edge connections are valid
        if not all(
            (
                are_connections_compatible(
                    self.get_node(e.source.node_id),
                    e.source.field,
                    self.get_node(e.destination.node_id),
                    e.destination.field,
                )
                for e in self.edges
            )
        ):
            return False

        # Validate all iterators
        # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
        if not all(
            (
                self._is_iterator_connection_valid(n.id)
                for n in self.nodes.values()
                if isinstance(n, IterateInvocation)
            )
        ):
            return False

        # Validate all collectors
        # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
        if not all(
            (
                self._is_collector_connection_valid(n.id)
                for n in self.nodes.values()
                if isinstance(n, CollectInvocation)
            )
        ):
            return False

        return True

    def _is_edge_valid(self, edge: Edge) -> bool:
        """Validates that a new edge doesn't create a cycle in the graph"""

        # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
        try:
            from_node = self.get_node(edge.source.node_id)
            to_node = self.get_node(edge.destination.node_id)
        except NodeNotFoundError:
            return False

        # Validate that an edge to this node+field doesn't already exist
        input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
        if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
            return False

        # Validate that no cycles would be created
        g = self.nx_graph_flat()
        g.add_edge(edge.source.node_id, edge.destination.node_id)
        if not nx.is_directed_acyclic_graph(g):
            return False

        # Validate that the field types are compatible
        if not are_connections_compatible(
            from_node, edge.source.field, to_node, edge.destination.field
        ):
            return False

        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
        if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
            if not self._is_iterator_connection_valid(
                edge.destination.node_id, new_input=edge.source
            ):
                return False

        # Validate if iterator input type matches output type (if this edge results in both being set)
        if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
            if not self._is_iterator_connection_valid(
                edge.source.node_id, new_output=edge.destination
            ):
                return False

        # Validate if collector input type matches output type (if this edge results in both being set)
        if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
            if not self._is_collector_connection_valid(
                edge.destination.node_id, new_input=edge.source
            ):
                return False

        # Validate if collector output type matches input type (if this edge results in both being set)
        if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
            if not self._is_collector_connection_valid(
                edge.source.node_id, new_output=edge.destination
            ):
                return False

        return True

    def has_node(self, node_path: str) -> bool:
        """Determines whether or not a node exists in the graph."""
        try:
            return self.get_node(node_path) is not None
        except NodeNotFoundError:
            return False

    def get_node(self, node_path: str) -> InvocationsUnion:
        """Gets a node from the graph using a node path.

        :raises NodeNotFoundError: the path does not resolve to a node.
        """
        # Materialized graphs may have nodes at the top level
        graph, node_id = self._get_graph_and_node(node_path)
        return graph.nodes[node_id]

    def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
        """Joins a prefix and a node id with a '.'; returns the bare id when there is no prefix."""
        return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"

    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
        """Updates a node in the graph.

        If the new node has a different id, the old node is removed and its edges are
        re-created against the new id.

        :raises TypeError: the new node is not the same type as the existing node.
        :raises NodeAlreadyInGraphError: the new id is already used by another node.
        """
        graph, node_id = self._get_graph_and_node(node_path)
        node = graph.nodes[node_id]

        # Ensure the node type matches the new node
        if type(node) is not type(new_node):
            raise TypeError(
                f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}"
            )

        # Ensure the new id is either the same or is not in the graph
        prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
        new_path = self._get_node_path(new_node.id, prefix=prefix)
        if new_node.id != node.id and self.has_node(new_path):
            raise NodeAlreadyInGraphError(
                f"Node with id {new_node.id} already exists in graph"
            )

        # Set the new node in the graph
        graph.nodes[new_node.id] = new_node
        if new_node.id != node.id:
            input_edges = self._get_input_edges_and_graphs(node_path)
            output_edges = self._get_output_edges_and_graphs(node_path)

            # Delete node and all edges. node_path is relative to this graph, so the
            # deletion must be resolved from here rather than from the containing subgraph.
            self.delete_node(node_path)

            # Create new edges for each input and output
            for edge_graph, _, edge in input_edges:
                # Keep the graph prefix of the old path and swap in the new node id
                old_path = edge.destination.node_id
                new_graph_node_path = (
                    new_node.id
                    if "." not in old_path
                    else f'{old_path[: old_path.rindex(".")]}.{new_node.id}'
                )
                edge_graph.add_edge(
                    Edge(
                        source=edge.source,
                        destination=EdgeConnection(
                            node_id=new_graph_node_path, field=edge.destination.field
                        ),
                    )
                )

            for edge_graph, _, edge in output_edges:
                # Keep the graph prefix of the old path and swap in the new node id
                old_path = edge.source.node_id
                new_graph_node_path = (
                    new_node.id
                    if "." not in old_path
                    else f'{old_path[: old_path.rindex(".")]}.{new_node.id}'
                )
                edge_graph.add_edge(
                    Edge(
                        source=EdgeConnection(
                            node_id=new_graph_node_path, field=edge.source.field
                        ),
                        destination=edge.destination,
                    )
                )

    def _get_input_edges(
        self, node_path: str, field: Optional[str] = None
    ) -> list[Edge]:
        """Gets all input edges for a node, optionally limited to one destination field"""

        edges = self._get_input_edges_and_graphs(node_path)

        # Filter to edges that match the field
        filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)

        # Create full node paths for each edge
        return [
            Edge(
                source=EdgeConnection(
                    node_id=self._get_node_path(e.source.node_id, prefix=prefix),
                    field=e.source.field,
                ),
                destination=EdgeConnection(
                    node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
                    field=e.destination.field,
                ),
            )
            for _, prefix, e in filtered_edges
        ]

    def _get_input_edges_and_graphs(
        self, node_path: str, prefix: Optional[str] = None
    ) -> list[tuple["Graph", Optional[str], Edge]]:
        """Gets all input edges for a node along with the graph they are in and the graph's path"""
        edges = list()

        # Return any input edges that appear in this graph
        edges.extend(
            [(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
        )

        node_id = (
            node_path if "." not in node_path else node_path[: node_path.index(".")]
        )
        node = self.nodes[node_id]

        # Only descend when some path remains after the graph node's own id; when the
        # path names the graph node itself there is nothing to look up in the subgraph.
        if isinstance(node, GraphInvocation) and len(node_path) > len(node_id):
            graph = node.graph
            graph_path = (
                node.id
                if prefix is None or prefix == ""
                else self._get_node_path(node.id, prefix=prefix)
            )
            graph_edges = graph._get_input_edges_and_graphs(
                node_path[(len(node_id) + 1) :], prefix=graph_path
            )
            edges.extend(graph_edges)

        return edges

    def _get_output_edges(
        self, node_path: str, field: str
    ) -> list[Edge]:
        """Gets all output edges for a node that leave from the given source field"""
        edges = self._get_output_edges_and_graphs(node_path)

        # Filter to edges that match the field
        filtered_edges = (e for e in edges if e[2].source.field == field)

        # Create full node paths for each edge
        return [
            Edge(
                source=EdgeConnection(
                    node_id=self._get_node_path(e.source.node_id, prefix=prefix),
                    field=e.source.field,
                ),
                destination=EdgeConnection(
                    node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
                    field=e.destination.field,
                ),
            )
            for _, prefix, e in filtered_edges
        ]

    def _get_output_edges_and_graphs(
        self, node_path: str, prefix: Optional[str] = None
    ) -> list[tuple["Graph", Optional[str], Edge]]:
        """Gets all output edges for a node along with the graph they are in and the graph's path"""
        edges = list()

        # Return any output edges that appear in this graph
        edges.extend(
            [(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
        )

        node_id = (
            node_path if "." not in node_path else node_path[: node_path.index(".")]
        )
        node = self.nodes[node_id]

        # Only descend when some path remains after the graph node's own id; when the
        # path names the graph node itself there is nothing to look up in the subgraph.
        if isinstance(node, GraphInvocation) and len(node_path) > len(node_id):
            graph = node.graph
            graph_path = (
                node.id
                if prefix is None or prefix == ""
                else self._get_node_path(node.id, prefix=prefix)
            )
            graph_edges = graph._get_output_edges_and_graphs(
                node_path[(len(node_id) + 1) :], prefix=graph_path
            )
            edges.extend(graph_edges)

        return edges

    def _is_iterator_connection_valid(
        self,
        node_path: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        """Checks an iterator's "collection" input against its "item" outputs.

        `new_input` / `new_output` are prospective connections that are checked together
        with the edges already in the graph.
        """
        inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
        outputs = [e.destination for e in self._get_output_edges(node_path, "item")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Only one input is allowed for iterators
        if len(inputs) > 1:
            return False

        # With no input connected there is no item type to check the outputs against yet
        if not inputs:
            return True

        # Get input and output fields (the fields linked to the iterator's input/output)
        input_field = get_output_field(
            self.get_node(inputs[0].node_id), inputs[0].field
        )
        output_fields = [
            get_input_field(self.get_node(e.node_id), e.field) for e in outputs
        ]

        # Input type must be a list
        if get_origin(input_field) != list:
            return False

        # Validate that all outputs match the input type
        input_field_item_type = get_args(input_field)[0]
        if not all(
            (
                are_connection_types_compatible(input_field_item_type, f)
                for f in output_fields
            )
        ):
            return False

        return True

    def _is_collector_connection_valid(
        self,
        node_path: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        """Checks a collector's "item" inputs against its "collection" outputs.

        `new_input` / `new_output` are prospective connections that are checked together
        with the edges already in the graph.
        """
        inputs = [e.source for e in self._get_input_edges(node_path, "item")]
        outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Get input and output fields (the fields linked to the collector's input/output)
        input_fields = [
            get_output_field(self.get_node(e.node_id), e.field) for e in inputs
        ]
        output_fields = [
            get_input_field(self.get_node(e.node_id), e.field) for e in outputs
        ]

        # Validate that all inputs are derived from or match a single type
        input_field_types = {
            t
            for input_field in input_fields
            for t in (
                [input_field]
                if get_origin(input_field) is None
                else get_args(input_field)
            )
            if t != NoneType
        }  # Get unique types
        type_tree = nx.DiGraph()
        type_tree.add_nodes_from(input_field_types)
        type_tree.add_edges_from(
            [
                e
                for e in itertools.permutations(input_field_types, 2)
                if issubclass(e[1], e[0])
            ]
        )
        type_degrees = type_tree.in_degree(type_tree.nodes)
        if sum((t[1] == 0 for t in type_degrees)) != 1:  # type: ignore
            return False  # There is not exactly one root type

        # Get the input root type
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        if not all((get_origin(f) == list for f in output_fields)):
            return False

        # Verify that all outputs match the input type (are a base class or the same class)
        if not all(
            (issubclass(input_root_type, get_args(f)[0]) for f in output_fields)
        ):
            return False

        return True

    def nx_graph(self) -> nx.DiGraph:
        """Returns a NetworkX DiGraph representing the layout of this graph"""
        # TODO: Cache this?
        g = nx.DiGraph()
        g.add_nodes_from(self.nodes.keys())
        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
        return g

    def nx_graph_flat(
        self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None
    ) -> nx.DiGraph:
        """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
        # An empty DiGraph is falsy, so test for None explicitly; otherwise a subgraph
        # would build into a throwaway graph whenever the accumulator is still empty.
        g = nx.DiGraph() if nx_graph is None else nx_graph

        # Add all nodes from this graph except graph/iteration nodes
        g.add_nodes_from(
            [
                self._get_node_path(n.id, prefix)
                for n in self.nodes.values()
                if not isinstance(n, GraphInvocation)
                and not isinstance(n, IterateInvocation)
            ]
        )

        # Expand graph nodes
        for sgn in (
            gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)
        ):
            sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))

        # TODO: figure out if iteration nodes need to be expanded

        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
        g.add_edges_from(
            [
                (self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
                for e in unique_edges
            ]
        )
        return g


class GraphExecutionState(BaseModel):
    """Tracks the state of a graph execution"""

    id: str = Field(
        description="The id of the execution state", default_factory=uuid.uuid4
    )

    # TODO: Store a reference to the graph instead of the actual graph?
    graph: Graph = Field(description="The graph being executed")

    # The graph of materialized nodes
    execution_graph: Graph = Field(
        description="The expanded graph of activated and executed nodes",
        default_factory=Graph,
    )

    # Nodes that have been executed
    executed: set[str] = Field(
        description="The set of node ids that have been executed", default_factory=set
    )
    executed_history: list[str] = Field(
        description="The list of node ids that have been executed, in order of execution",
        default_factory=list,
    )

    # The results of executed nodes
    results: dict[
        str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
    ] = Field(description="The results of node executions", default_factory=dict)

    # Errors raised when executing nodes
    errors: dict[str, str] = Field(
        description="Errors raised when executing nodes", default_factory=dict
    )

    # Map of prepared/executed nodes to their original nodes
    prepared_source_mapping: dict[str, str] = Field(
        description="The map of prepared nodes to original graph nodes",
        default_factory=dict,
    )

    # Map of original nodes to prepared nodes
    source_prepared_mapping: dict[str, set[str]] = Field(
        description="The map of original graph nodes to prepared nodes",
        default_factory=dict,
    )

    class Config:
        schema_extra = {
            'required': [
                'id',
                'graph',
                'execution_graph',
                'executed',
                'executed_history',
                'results',
                'errors',
                'prepared_source_mapping',
                'source_prepared_mapping',
            ]
        }

    def next(self) -> BaseInvocation | None:
        """Gets the next node ready to execute."""

        # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
        #       possibly with a timeout?

        # If there are no prepared nodes, prepare some nodes
        next_node = self._get_next_node()
        if next_node is None:
            prepared_id = self._prepare()

            # TODO: prepare multiple nodes at once?
            # while prepared_id is not None and not isinstance(self.graph.nodes[prepared_id], IterateInvocation):
            #     prepared_id = self._prepare()

            if prepared_id is not None:
                next_node = self._get_next_node()

        # Get values from edges
        if next_node is not None:
            self._prepare_inputs(next_node)

        # If next is still none, there's no next node, return None
        return next_node

    def complete(self, node_id: str, output: InvocationOutputsUnion):
        """Marks a node as complete"""

        if node_id not in self.execution_graph.nodes:
            return  # TODO: log error?

        # Mark node as executed
        self.executed.add(node_id)
        self.results[node_id] = output

        # Check if source node is complete (all prepared nodes are complete)
        source_node = self.prepared_source_mapping[node_id]
        prepared_nodes = self.source_prepared_mapping[source_node]

        if all(n in self.executed for n in prepared_nodes):
            self.executed.add(source_node)
            self.executed_history.append(source_node)

    def set_node_error(self, node_id: str, error: str):
        """Marks a node as errored"""
        self.errors[node_id] = error

    def is_complete(self) -> bool:
        """Returns true if the graph is complete"""
        return self.has_error() or all((k in self.executed for k in self.graph.nodes))

    def has_error(self) -> bool:
        """Returns true if the graph has any errors"""
        return len(self.errors) > 0

    def _create_execution_node(
        self, node_path: str, iteration_node_map: list[tuple[str, str]]
    ) -> list[str]:
        """Prepares an iteration node and connects all edges, returning the new node ids

        `iteration_node_map` holds (source node path, prepared node id) pairs that select
        which prepared parent each incoming edge is wired to.
        """

        node = self.graph.get_node(node_path)

        self_iteration_count = -1

        # If this is an iterator node, we must create a copy for each iteration
        if isinstance(node, IterateInvocation):
            # Get input collection edge (should error if there are no inputs)
            input_collection_edge = next(
                iter(self.graph._get_input_edges(node_path, "collection"))
            )
            input_collection_prepared_node_id = next(
                n[1]
                for n in iteration_node_map
                if n[0] == input_collection_edge.source.node_id
            )
            input_collection_prepared_node_output = self.results[
                input_collection_prepared_node_id
            ]
            input_collection = getattr(
                input_collection_prepared_node_output, input_collection_edge.source.field
            )
            self_iteration_count = len(input_collection)

        new_nodes = list()
        if self_iteration_count == 0:
            # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
            return new_nodes

        # Get all input edges
        input_edges = self.graph._get_input_edges(node_path)

        # Create new edges for this iteration
        # For collect nodes, this may contain multiple inputs to the same field
        new_edges = list()
        for edge in input_edges:
            for input_node_id in (
                n[1] for n in iteration_node_map if n[0] == edge.source.node_id
            ):
                new_edge = Edge(
                    source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
                    # The destination id is filled in once each new node has its id
                    destination=EdgeConnection(node_id="", field=edge.destination.field),
                )
                new_edges.append(new_edge)

        # Create a new node (or one for each iteration of this iterator)
        for i in range(self_iteration_count) if self_iteration_count > 0 else [-1]:
            # Create a new node
            new_node = copy.deepcopy(node)

            # Create the node id (use a random uuid)
            new_node.id = str(uuid.uuid4())

            # Set the iteration index for iteration invocations
            if isinstance(new_node, IterateInvocation):
                new_node.index = i

            # Add to execution graph
            self.execution_graph.add_node(new_node)
            self.prepared_source_mapping[new_node.id] = node_path
            self.source_prepared_mapping.setdefault(node_path, set()).add(new_node.id)

            # Add new edges to execution graph
            for edge in new_edges:
                new_edge = Edge(
                    source=edge.source,
                    destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
                )
                self.execution_graph.add_edge(new_edge)

            new_nodes.append(new_node.id)

        return new_nodes

    def _iterator_graph(self) -> nx.DiGraph:
        """Gets a DiGraph with edges to collectors removed so an ancestor search produces all active iterators for any node"""
        g = self.graph.nx_graph()
        collectors = (
            n
            for n in self.graph.nodes
            if isinstance(self.graph.nodes[n], CollectInvocation)
        )
        for c in collectors:
            g.remove_edges_from(list(g.in_edges(c)))
        return g

    def _get_node_iterators(self, node_id: str) -> list[str]:
        """Gets iterators for a node"""
        g = self._iterator_graph()
        iterators = [
            n
            for n in nx.ancestors(g, node_id)
            if isinstance(self.graph.nodes[n], IterateInvocation)
        ]
        return iterators

    def _prepare(self) -> Optional[str]:
        """Prepares execution nodes for the next unprepared source node whose parents have all executed.

        Returns the id of one newly prepared node, or None if nothing was prepared.
        """
        # Get flattened source graph
        g = self.graph.nx_graph_flat()

        # Find next unprepared node where all source nodes are executed
        sorted_nodes = nx.topological_sort(g)
        next_node_id = next(
            (
                n
                for n in sorted_nodes
                if n not in self.source_prepared_mapping
                and all((e[0] in self.executed for e in g.in_edges(n)))
            ),
            None,
        )

        if next_node_id is None:
            return None

        # Get all parents of the next node
        next_node_parents = [e[0] for e in g.in_edges(next_node_id)]

        # Create execution nodes
        next_node = self.graph.get_node(next_node_id)
        new_node_ids = list()
        if isinstance(next_node, CollectInvocation):
            # Collapse all iterator input mappings and create a single execution node for the collect invocation
            # Built eagerly in one comprehension: lazily chained nested generators would
            # all read the parent id after the outer loop finished and see only the last one.
            all_iteration_mappings = [
                (s, p)
                for s in next_node_parents
                for p in self.source_prepared_mapping[s]
            ]
            create_results = self._create_execution_node(
                next_node_id, all_iteration_mappings
            )
            if create_results is not None:
                new_node_ids.extend(create_results)
        else:  # Iterators or normal nodes
            # Get all iterator combinations for this node
            # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
            iterator_nodes = self._get_node_iterators(next_node_id)
            iterator_nodes_prepared = [
                list(self.source_prepared_mapping[n]) for n in iterator_nodes
            ]
            iterator_node_prepared_combinations = list(
                itertools.product(*iterator_nodes_prepared)
            )

            # Select the correct prepared parents for each iteration
            # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
            # TODO: Handle a node mapping to none
            eg = self.execution_graph.nx_graph_flat()
            prepared_parent_mappings = [
                [(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]  # type: ignore
                for it in iterator_node_prepared_combinations
            ]

            # Create execution node for each iteration
            for iteration_mappings in prepared_parent_mappings:
                create_results = self._create_execution_node(next_node_id, iteration_mappings)  # type: ignore
                if create_results is not None:
                    new_node_ids.extend(create_results)

        return next(iter(new_node_ids), None)

    def _get_iteration_node(
        self,
        source_node_path: str,
        graph: nx.DiGraph,
        execution_graph: nx.DiGraph,
        prepared_iterator_nodes: list[str],
    ) -> Optional[str]:
        """Gets the prepared version of the specified source node that matches every iteration specified"""
        prepared_nodes = self.source_prepared_mapping[source_node_path]
        if len(prepared_nodes) == 1:
            return next(iter(prepared_nodes))

        # Check if the requested node is an iterator
        prepared_iterator = next(
            (n for n in prepared_nodes if n in prepared_iterator_nodes), None
        )
        if prepared_iterator is not None:
            return prepared_iterator

        # Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
        iterator_source_node_mapping = [
            (n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes
        ]
        parent_iterators = [
            itn
            for itn in iterator_source_node_mapping
            if nx.has_path(graph, itn[1], source_node_path)
        ]

        return next(
            (
                n
                for n in prepared_nodes
                if all(
                    nx.has_path(execution_graph, pit[0], n)
                    for pit in parent_iterators
                )
            ),
            None,
        )

    def _get_next_node(self) -> Optional[BaseInvocation]:
        """Returns the first not-yet-executed node of the execution graph in topological order, if any"""
        g = self.execution_graph.nx_graph()
        sorted_nodes = nx.topological_sort(g)
        next_node = next((n for n in sorted_nodes if n not in self.executed), None)
        if next_node is None:
            return None

        return self.execution_graph.nodes[next_node]

    def _prepare_inputs(self, node: BaseInvocation):
        """Copies the results of the node's executed parents onto the node's input fields"""
        input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
        if isinstance(node, CollectInvocation):
            # Collectors gather every "item" input into their collection
            output_collection = [
                getattr(self.results[edge.source.node_id], edge.source.field)
                for edge in input_edges
                if edge.destination.field == "item"
            ]
            setattr(node, "collection", output_collection)
        else:
            for edge in input_edges:
                output_value = getattr(self.results[edge.source.node_id], edge.source.field)
                setattr(node, edge.destination.field, output_value)

    # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
    def _is_edge_valid(self, edge: Edge) -> bool:
        """Checks an edge against the source graph and the current execution state"""
        # Delegate structural validation to the source graph (calling this method on
        # self here would recurse forever)
        if not self.graph._is_edge_valid(edge):
            return False

        # Invalid if destination has already been prepared or executed
        if edge.destination.node_id in self.source_prepared_mapping:
            return False

        # Otherwise, the edge is valid
        return True

    def _is_node_updatable(self, node_id: str) -> bool:
        # The node is updatable as long as it hasn't been prepared or executed
        return node_id not in self.source_prepared_mapping

    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to the source graph"""
        self.graph.add_node(node)

    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
        """Updates a node in the source graph

        :raises NodeAlreadyExecutedError: the node has already been prepared or executed.
        """
        if not self._is_node_updatable(node_path):
            raise NodeAlreadyExecutedError(
                f"Node {node_path} has already been prepared or executed and cannot be updated"
            )
        self.graph.update_node(node_path, new_node)

    def delete_node(self, node_path: str) -> None:
        """Deletes a node from the source graph

        :raises NodeAlreadyExecutedError: the node has already been prepared or executed.
        """
        if not self._is_node_updatable(node_path):
            raise NodeAlreadyExecutedError(
                f"Node {node_path} has already been prepared or executed and cannot be deleted"
            )
        self.graph.delete_node(node_path)

    def add_edge(self, edge: Edge) -> None:
        """Adds an edge to the source graph

        :raises NodeAlreadyExecutedError: the destination node has already been prepared or executed.
        """
        if not self._is_node_updatable(edge.destination.node_id):
            raise NodeAlreadyExecutedError(
                f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
            )
        self.graph.add_edge(edge)

    def delete_edge(self, edge: Edge) -> None:
        """Deletes an edge from the source graph

        :raises NodeAlreadyExecutedError: the destination node has already been prepared or executed.
        """
        if not self._is_node_updatable(edge.destination.node_id):
            raise NodeAlreadyExecutedError(
                f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
            )
        self.graph.delete_edge(edge)


GraphInvocation.update_forward_refs()