InvokeAI/invokeai/app/services/graph.py

1265 lines
46 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
import copy
import itertools
import uuid
2023-03-03 06:02:00 +00:00
from types import NoneType
from typing import (
Annotated,
Any,
Literal,
Optional,
Union,
get_args,
get_origin,
get_type_hints,
)
import networkx as nx
from pydantic import BaseModel, root_validator, validator
from pydantic.fields import Field
from ..invocations import *
2023-03-03 06:02:00 +00:00
from ..invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
InvocationContext,
)
class EdgeConnection(BaseModel):
    # One endpoint of an edge: a node (by id or dotted node path) and one of its fields.
    node_id: str = Field(description="The id of the node for this edge connection")
    field: str = Field(description="The field for this connection")

    def __eq__(self, other):
        # Only another instance of this same class can compare equal.
        if not isinstance(other, self.__class__):
            return False
        same_node = getattr(other, "node_id", None) == self.node_id
        return same_node and getattr(other, "field", None) == self.field

    def __hash__(self):
        # Hash the "node_id.field" string so equal connections hash identically.
        key = f"{self.node_id}.{self.field}"
        return hash(key)
class Edge(BaseModel):
    # A directed connection from a field on the source node to a field on the destination node.
    source: EdgeConnection = Field(
        description="The connection for the edge's from node and field"
    )
    destination: EdgeConnection = Field(
        description="The connection for the edge's to node and field"
    )
def get_output_field(node: BaseInvocation, field: str) -> Any:
    """Return the type annotation of ``field`` on ``node``'s output class.

    The output class is obtained via ``type(node).get_output_type()``. Returns None
    when the output class has no annotation for ``field``.
    """
    output_class = type(node).get_output_type()
    return get_type_hints(output_class).get(field) or None
def get_input_field(node: BaseInvocation, field: str) -> Any:
    """Return the type annotation of input ``field`` on ``node``'s own class.

    Returns None when the class has no annotation for ``field``.
    """
    hints = get_type_hints(type(node))
    return hints.get(field) or None
from typing import Optional, Union, List, get_args
def is_union_subtype(t1, t2):
    """Return True if every member type of ``t1`` appears among ``t2``'s type arguments.

    Members are taken with ``get_args``. A type with no arguments counts as a single
    member (itself). Note that ``get_args`` also returns the arguments of non-Union
    generics (e.g. ``list[int]`` -> ``(int,)``), so those are treated like unions here.
    """
    candidates = get_args(t2)
    members = get_args(t1)
    if members:
        # t1 is a Union (or other generic): every member must be present in t2
        return all(member in candidates for member in members)
    # t1 is a plain type
    return t1 in candidates
def is_list_or_contains_list(t):
    """Return True if ``t`` is a parameterized list, or any of its type arguments is one.

    Type arguments come from ``get_args``, so this covers ``Optional[list[X]]`` and
    ``Union[X, list[X]]`` (and, as a side effect, other generics such as
    ``dict[str, list[X]]``).
    """
    if get_origin(t) is list:
        return True
    return any(get_origin(member) is list for member in get_args(t))
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
    """Return True if a value of ``from_type`` may be connected to a field of ``to_type``.

    Either type being falsy (e.g. None for a missing field) makes the connection
    incompatible. Otherwise the types are compatible when they are equal, when either
    is ``Any`` or has ``Any`` among its type arguments, when one appears among the
    other's type arguments, or when ``is_union_subtype(from_type, to_type)`` holds.
    """
    if not from_type or not to_type:
        return False

    # TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
    if (
        from_type == to_type
        or from_type == Any
        or to_type == Any
        or Any in get_args(from_type)
        or Any in get_args(to_type)
    ):
        return True

    # One side is a member/argument of the other (e.g. int -> Optional[int])
    if from_type in get_args(to_type) or to_type in get_args(from_type):
        return True

    return is_union_subtype(from_type, to_type)
def are_connections_compatible(
    from_node: BaseInvocation, from_field: str, to_node: BaseInvocation, to_field: str
) -> bool:
    """Return whether ``from_node.from_field`` (an output) may feed ``to_node.to_field`` (an input).

    Looks up the declared output and input field types and delegates the decision to
    ``are_connection_types_compatible``.
    """
    # TODO: handle iterators and collectors
    return are_connection_types_compatible(
        get_output_field(from_node, from_field),
        get_input_field(to_node, to_field),
    )
class NodeAlreadyInGraphError(Exception):
    """Raised when adding or renaming a node would duplicate an existing node id."""
class InvalidEdgeError(Exception):
    """Raised when an edge cannot be added to a graph."""
class NodeNotFoundError(Exception):
    """Raised when a node id or node path does not resolve to a node in the graph."""
class NodeAlreadyExecutedError(Exception):
    """Error type for a node that has already been executed."""
# TODO: Create and use an Empty output?
class GraphInvocationOutput(BaseInvocationOutput):
    type: Literal["graph_output"] = "graph_output"

    class Config:
        # Fields listed as required in the generated JSON schema. Only fields this model
        # actually declares belong here; this output has no "image" field.
        schema_extra = {
            'required': [
                'type',
            ]
        }
# TODO: Fill this out and move to invocations
class GraphInvocation(BaseInvocation):
    """Execute a graph"""

    type: Literal["graph"] = "graph"

    # TODO: figure out how to create a default here
    graph: "Graph" = Field(
        description="The graph to run",
        default=None,
    )

    def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
        """Return an empty GraphInvocationOutput; the subgraph itself is not run here."""
        result = GraphInvocationOutput()
        return result
class IterateInvocationOutput(BaseInvocationOutput):
    """Used to connect iteration outputs. Will be expanded to a specific output."""

    type: Literal["iterate_output"] = "iterate_output"

    # The single element produced by one iteration
    item: Any = Field(description="The item being iterated over")

    class Config:
        # Both fields are always present in serialized output
        schema_extra = {
            'required': ['type', 'item'],
        }
# TODO: Fill this out and move to invocations
class IterateInvocation(BaseInvocation):
    """Iterates over a list of items"""

    type: Literal["iterate"] = "iterate"

    collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list)
    index: int = Field(description="The index, will be provided on executed iterators", default=0)

    def invoke(self, context: InvocationContext) -> IterateInvocationOutput:
        """Emit the element of ``collection`` at position ``index`` as this iteration's item."""
        current = self.collection[self.index]
        return IterateInvocationOutput(item=current)
class CollectInvocationOutput(BaseInvocationOutput):
    type: Literal["collect_output"] = "collect_output"

    # Everything gathered by the collector
    collection: list[Any] = Field(description="The collection of input items")

    class Config:
        # Both fields are always present in serialized output
        schema_extra = {
            'required': ['type', 'collection'],
        }
class CollectInvocation(BaseInvocation):
    """Collects values into a collection"""

    type: Literal["collect"] = "collect"

    item: Any = Field(
        description="The item to collect (all inputs must be of the same type)",
        default=None,
    )
    collection: list[Any] = Field(
        description="The collection, will be provided on execution",
        default_factory=list,
    )

    def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
        """Return a shallow copy of the gathered ``collection``."""
        # Copy so the output does not alias this node's own list
        snapshot = copy.copy(self.collection)
        return CollectInvocationOutput(collection=snapshot)
# Union of the invocation classes returned by BaseInvocation.get_invocations();
# Graph.nodes uses it as a discriminated union keyed on each node's `type` field.
InvocationsUnion = Union[BaseInvocation.get_invocations()] # type: ignore
# Union of the classes returned by BaseInvocationOutput.get_all_subclasses_tuple();
# used (discriminated on `type`) for execution results.
InvocationOutputsUnion = Union[BaseInvocationOutput.get_all_subclasses_tuple()] # type: ignore
class Graph(BaseModel):
    # Nodes keyed by id plus the edges connecting their fields. A node may be a
    # GraphInvocation holding a subgraph, so many methods accept dotted node paths
    # such as "subgraph_id.node_id".
    id: str = Field(description="The id of this graph", default_factory=lambda: uuid.uuid4().__str__())
    # TODO: use a list (and never use dict in a BaseModel) because pydantic/fastapi hates me
    nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
        description="The nodes in this graph", default_factory=dict
    )
    edges: list[Edge] = Field(
        description="The connections between nodes and their fields in this graph",
        default_factory=list,
    )

    def add_node(self, node: BaseInvocation) -> None:
        """Adds a node to a graph

        :raises NodeAlreadyInGraphError: the node is already present in the graph.
        """
        if node.id in self.nodes:
            raise NodeAlreadyInGraphError()

        self.nodes[node.id] = node

    def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
        """Returns the graph and node id for a node path.

        :raises NodeNotFoundError: the path does not resolve to a node.
        """
        # Materialized graphs may have nodes at the top level
        if node_path in self.nodes:
            return (self, node_path)

        node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
        if node_id not in self.nodes:
            raise NodeNotFoundError(f"Node {node_path} not found in graph")

        node = self.nodes[node_id]

        if not isinstance(node, GraphInvocation):
            # There's more node path left but this isn't a graph - failure
            raise NodeNotFoundError("Node path terminated early at a non-graph node")

        # Resolve the remainder of the path inside the subgraph
        return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])

    def delete_node(self, node_path: str) -> None:
        """Deletes a node and every edge connected to it.

        Node paths that do not resolve to a node are silently ignored.
        """
        try:
            graph, node_id = self._get_graph_and_node(node_path)

            # Delete edges for this node (they may live in this graph or in subgraphs)
            input_edges = self._get_input_edges_and_graphs(node_path)
            output_edges = self._get_output_edges_and_graphs(node_path)

            for edge_graph, _, edge in input_edges:
                edge_graph.delete_edge(edge)

            for edge_graph, _, edge in output_edges:
                edge_graph.delete_edge(edge)

            del graph.nodes[node_id]

        except NodeNotFoundError:
            pass  # Ignore, node doesn't exist (should this throw?)

    def add_edge(self, edge: Edge) -> None:
        """Adds an edge to a graph

        :raises InvalidEdgeError: the provided edge is invalid.
        """
        self._validate_edge(edge)
        if edge not in self.edges:
            self.edges.append(edge)
        else:
            raise InvalidEdgeError()

    def delete_edge(self, edge: Edge) -> None:
        """Deletes an edge from a graph; an edge that is not present is ignored."""
        try:
            self.edges.remove(edge)
        except ValueError:
            # list.remove signals a missing element with ValueError (not KeyError)
            pass

    def is_valid(self) -> bool:
        """Validates the graph.

        Returns False if any subgraph is invalid, any edge references a missing node,
        the flattened graph has a cycle, any edge connects incompatible field types,
        or any iterator/collector node has invalid connections.
        """
        # Validate all subgraphs
        for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
            if not gn.graph.is_valid():
                return False

        # Validate all edges reference nodes in the graph
        node_ids = {e.source.node_id for e in self.edges} | {e.destination.node_id for e in self.edges}
        if not all(self.has_node(node_id) for node_id in node_ids):
            return False

        # Validate there are no cycles
        g = self.nx_graph_flat()
        if not nx.is_directed_acyclic_graph(g):
            return False

        # Validate all edge connections are valid
        if not all(
            are_connections_compatible(
                self.get_node(e.source.node_id),
                e.source.field,
                self.get_node(e.destination.node_id),
                e.destination.field,
            )
            for e in self.edges
        ):
            return False

        # Validate all iterators
        # TODO: may need to validate all iterators in subgraphs so edge connections in parent graphs will be available
        if not all(
            self._is_iterator_connection_valid(n.id)
            for n in self.nodes.values()
            if isinstance(n, IterateInvocation)
        ):
            return False

        # Validate all collectors
        # TODO: may need to validate all collectors in subgraphs so edge connections in parent graphs will be available
        if not all(
            self._is_collector_connection_valid(n.id)
            for n in self.nodes.values()
            if isinstance(n, CollectInvocation)
        ):
            return False

        return True

    def _validate_edge(self, edge: Edge):
        """Validates that a new edge may be added to the graph.

        Checks that both endpoints exist, that the destination field has no other
        input (collectors excepted), that no cycle is created, that the field types
        are compatible, and that iterator/collector type constraints still hold.

        :raises InvalidEdgeError: any check fails.
        """
        # Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
        try:
            from_node = self.get_node(edge.source.node_id)
            to_node = self.get_node(edge.destination.node_id)
        except NodeNotFoundError:
            raise InvalidEdgeError("One or both nodes don't exist")

        # Validate that an edge to this node+field doesn't already exist
        input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
        if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
            raise InvalidEdgeError(f'Edge to node {edge.destination.node_id} field {edge.destination.field} already exists')

        # Validate that no cycles would be created
        g = self.nx_graph_flat()
        g.add_edge(edge.source.node_id, edge.destination.node_id)
        if not nx.is_directed_acyclic_graph(g):
            raise InvalidEdgeError('Edge creates a cycle in the graph')

        # Validate that the field types are compatible
        if not are_connections_compatible(from_node, edge.source.field, to_node, edge.destination.field):
            raise InvalidEdgeError('Fields are incompatible')

        # Validate if iterator output type matches iterator input type (if this edge results in both being set)
        if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
            if not self._is_iterator_connection_valid(edge.destination.node_id, new_input=edge.source):
                raise InvalidEdgeError('Iterator input type does not match iterator output type')

        # Validate if iterator input type matches output type (if this edge results in both being set)
        if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
            if not self._is_iterator_connection_valid(edge.source.node_id, new_output=edge.destination):
                raise InvalidEdgeError('Iterator output type does not match iterator input type')

        # Validate if collector input type matches output type (if this edge results in both being set)
        if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
            if not self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source):
                raise InvalidEdgeError('Collector output type does not match collector input type')

        # Validate if collector output type matches input type (if this edge results in both being set)
        if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
            if not self._is_collector_connection_valid(edge.source.node_id, new_output=edge.destination):
                raise InvalidEdgeError('Collector input type does not match collector output type')

    def has_node(self, node_path: str) -> bool:
        """Determines whether or not a node exists in the graph."""
        try:
            return self.get_node(node_path) is not None
        except NodeNotFoundError:
            return False

    def get_node(self, node_path: str) -> InvocationsUnion:
        """Gets a node from the graph using a node path.

        :raises NodeNotFoundError: the path does not resolve to a node.
        """
        # Materialized graphs may have nodes at the top level
        graph, node_id = self._get_graph_and_node(node_path)
        return graph.nodes[node_id]

    def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
        """Joins ``prefix`` and ``node_id`` with a dot (no dot when prefix is None or empty)."""
        return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"

    def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
        """Updates a node in the graph.

        ``new_node`` must have the same type as the existing node. If its id differs,
        the old node is removed and every edge into or out of it is re-created against
        the new id.

        :raises NodeNotFoundError: ``node_path`` does not resolve to a node.
        :raises TypeError: ``new_node`` is a different type than the existing node.
        :raises NodeAlreadyInGraphError: the new id is already used in the graph.
        """
        graph, node_id = self._get_graph_and_node(node_path)
        node = graph.nodes[node_id]

        # Ensure the node type matches the new node
        if type(node) != type(new_node):
            raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")

        # Ensure the new id is either the same or is not in the graph
        prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
        new_path = self._get_node_path(new_node.id, prefix=prefix)
        if new_node.id != node.id and self.has_node(new_path):
            raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")

        # Set the new node in the graph
        graph.nodes[new_node.id] = new_node
        if new_node.id != node.id:
            input_edges = self._get_input_edges_and_graphs(node_path)
            output_edges = self._get_output_edges_and_graphs(node_path)

            # Delete the old node and all its edges. node_path is relative to self, so the
            # deletion must go through self rather than the (sub)graph holding the node.
            self.delete_node(node_path)

            def repoint(path: str) -> str:
                # Keep any graph prefix of an edge endpoint and swap in the new node id:
                # "old" -> "new", "sub.old" -> "sub.new".
                return new_node.id if "." not in path else f"{path[: path.rindex('.')]}.{new_node.id}"

            # Create new edges for each input and output
            for edge_graph, _, edge in input_edges:
                edge_graph.add_edge(
                    Edge(
                        source=edge.source,
                        destination=EdgeConnection(
                            node_id=repoint(edge.destination.node_id), field=edge.destination.field
                        ),
                    )
                )

            for edge_graph, _, edge in output_edges:
                edge_graph.add_edge(
                    Edge(
                        source=EdgeConnection(node_id=repoint(edge.source.node_id), field=edge.source.field),
                        destination=edge.destination,
                    )
                )

    def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
        """Gets all input edges for a node (optionally only those into ``field``), with full node paths."""
        edges = self._get_input_edges_and_graphs(node_path)

        # Filter to edges that match the field
        filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)

        # Create full node paths for each edge
        return [
            Edge(
                source=EdgeConnection(
                    node_id=self._get_node_path(e.source.node_id, prefix=prefix),
                    field=e.source.field,
                ),
                destination=EdgeConnection(
                    node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
                    field=e.destination.field,
                ),
            )
            for _, prefix, e in filtered_edges
        ]

    def _get_input_edges_and_graphs(
        self, node_path: str, prefix: Optional[str] = None
    ) -> list[tuple["Graph", str, Edge]]:
        """Gets all input edges for a node along with the graph they are in and the graph's path"""
        edges = []

        # Return any input edges that appear in this graph
        edges.extend((self, prefix, e) for e in self.edges if e.destination.node_id == node_path)

        node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
        node = self.nodes[node_id]
        if isinstance(node, GraphInvocation):
            # NOTE(review): when node_path names a GraphInvocation itself (no "."), this recurses
            # with an empty path and the lookup of node "" in the subgraph raises KeyError —
            # confirm whether deleting/updating subgraph nodes is meant to be supported.
            graph = node.graph
            graph_path = self._get_node_path(node.id, prefix=prefix)
            graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
            edges.extend(graph_edges)

        return edges

    def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
        """Gets all output edges from ``field`` of a node, with full node paths."""
        edges = self._get_output_edges_and_graphs(node_path)

        # Filter to edges that match the field
        filtered_edges = (e for e in edges if e[2].source.field == field)

        # Create full node paths for each edge
        return [
            Edge(
                source=EdgeConnection(
                    node_id=self._get_node_path(e.source.node_id, prefix=prefix),
                    field=e.source.field,
                ),
                destination=EdgeConnection(
                    node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
                    field=e.destination.field,
                ),
            )
            for _, prefix, e in filtered_edges
        ]

    def _get_output_edges_and_graphs(
        self, node_path: str, prefix: Optional[str] = None
    ) -> list[tuple["Graph", str, Edge]]:
        """Gets all output edges for a node along with the graph they are in and the graph's path"""
        edges = []

        # Return any output edges that appear in this graph
        edges.extend((self, prefix, e) for e in self.edges if e.source.node_id == node_path)

        node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
        node = self.nodes[node_id]
        if isinstance(node, GraphInvocation):
            # NOTE(review): same empty-path recursion caveat as _get_input_edges_and_graphs.
            graph = node.graph
            graph_path = self._get_node_path(node.id, prefix=prefix)
            graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
            edges.extend(graph_edges)

        return edges

    def _is_iterator_connection_valid(
        self,
        node_path: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        """Checks an iterator's connections, optionally including a prospective new edge.

        Valid when the iterator has exactly one "collection" input whose type is a list,
        and every "item" output accepts that list's item type.
        """
        inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
        outputs = [e.destination for e in self._get_output_edges(node_path, "item")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Exactly one input is required: more is not allowed, and with none there is no
        # item type to validate the outputs against (inputs[0] would raise IndexError).
        if len(inputs) != 1:
            return False

        # Get input and output fields (the fields linked to the iterator's input/output)
        input_field = get_output_field(self.get_node(inputs[0].node_id), inputs[0].field)
        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

        # Input type must be a list
        if get_origin(input_field) != list:
            return False

        # Validate that all outputs match the input type
        input_field_item_type = get_args(input_field)[0]
        if not all(are_connection_types_compatible(input_field_item_type, f) for f in output_fields):
            return False

        return True

    def _is_collector_connection_valid(
        self,
        node_path: str,
        new_input: Optional[EdgeConnection] = None,
        new_output: Optional[EdgeConnection] = None,
    ) -> bool:
        """Checks a collector's connections, optionally including a prospective new edge.

        Valid when the "item" inputs' types form a single root type (all others are
        subclasses of it), every "collection" output is a list (or a union containing
        one), and the root type is a subclass of each output's item type.
        """
        inputs = [e.source for e in self._get_input_edges(node_path, "item")]
        outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]

        if new_input is not None:
            inputs.append(new_input)
        if new_output is not None:
            outputs.append(new_output)

        # Get input and output fields (the fields linked to the collector's input/output)
        input_fields = [get_output_field(self.get_node(e.node_id), e.field) for e in inputs]
        output_fields = [get_input_field(self.get_node(e.node_id), e.field) for e in outputs]

        # Validate that all inputs are derived from or match a single type
        input_field_types = {
            t
            for input_field in input_fields
            for t in ([input_field] if get_origin(input_field) is None else get_args(input_field))
            if t != NoneType
        }  # Get unique types
        type_tree = nx.DiGraph()
        type_tree.add_nodes_from(input_field_types)
        type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
        type_degrees = type_tree.in_degree(type_tree.nodes)
        if sum((t[1] == 0 for t in type_degrees)) != 1:  # type: ignore
            return False  # There is more than one root type (or no inputs at all)

        # Get the input root type
        input_root_type = next(t[0] for t in type_degrees if t[1] == 0)  # type: ignore

        # Verify that all outputs are lists
        if not all(is_list_or_contains_list(f) for f in output_fields):
            return False

        def item_type(f):
            # The first type argument of f is compared against the root type. When f is a
            # Union whose first member is itself a parameterized list (e.g. Optional[list[X]]),
            # use that list's item type: issubclass() raises TypeError on parameterized generics.
            item = get_args(f)[0]
            if get_origin(f) is not list and get_origin(item) is list and get_args(item):
                item = get_args(item)[0]
            return item

        # Verify that all outputs match the input type (are a base class or the same class)
        if not all(issubclass(input_root_type, item_type(f)) for f in output_fields):
            return False

        return True

    def nx_graph(self) -> nx.DiGraph:
        """Returns a NetworkX DiGraph representing the layout of this graph"""
        # TODO: Cache this?
        g = nx.DiGraph()
        g.add_nodes_from(list(self.nodes.keys()))
        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
        return g

    def nx_graph_with_data(self) -> nx.DiGraph:
        """Returns a NetworkX DiGraph representing the data and layout of this graph"""
        g = nx.DiGraph()
        g.add_nodes_from(list(self.nodes.items()))
        g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
        return g

    def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
        """Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
        # Compare with None: an empty DiGraph is falsy, so `nx_graph or ...` would drop a
        # caller-supplied empty graph.
        g = nx_graph if nx_graph is not None else nx.DiGraph()

        # Add all nodes from this graph except graph/iteration nodes
        g.add_nodes_from(
            [
                self._get_node_path(n.id, prefix)
                for n in self.nodes.values()
                if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
            ]
        )

        # Expand graph nodes
        for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
            g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))

        # TODO: figure out if iteration nodes need to be expanded

        unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
        g.add_edges_from(
            [(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges]
        )
        return g
class GraphExecutionState(BaseModel):
"""Tracks the state of a graph execution"""
2023-03-03 06:02:00 +00:00
id: str = Field(description="The id of the execution state", default_factory=lambda: uuid.uuid4().__str__())
# TODO: Store a reference to the graph instead of the actual graph?
graph: Graph = Field(description="The graph being executed")
# The graph of materialized nodes
2023-03-03 06:02:00 +00:00
execution_graph: Graph = Field(
description="The expanded graph of activated and executed nodes",
default_factory=Graph,
)
# Nodes that have been executed
2023-03-03 06:02:00 +00:00
executed: set[str] = Field(
description="The set of node ids that have been executed", default_factory=set
)
executed_history: list[str] = Field(
description="The list of node ids that have been executed, in order of execution",
default_factory=list,
)
# The results of executed nodes
2023-03-03 06:02:00 +00:00
results: dict[
str, Annotated[InvocationOutputsUnion, Field(discriminator="type")]
] = Field(description="The results of node executions", default_factory=dict)
# Errors raised when executing nodes
2023-03-03 06:02:00 +00:00
errors: dict[str, str] = Field(
description="Errors raised when executing nodes", default_factory=dict
)
# Map of prepared/executed nodes to their original nodes
2023-03-03 06:02:00 +00:00
prepared_source_mapping: dict[str, str] = Field(
description="The map of prepared nodes to original graph nodes",
default_factory=dict,
)
# Map of original nodes to prepared nodes
2023-03-03 06:02:00 +00:00
source_prepared_mapping: dict[str, set[str]] = Field(
description="The map of original graph nodes to prepared nodes",
default_factory=dict,
)
class Config:
schema_extra = {
'required': [
'id',
'graph',
'execution_graph',
'executed',
'executed_history',
'results',
'errors',
'prepared_source_mapping',
'source_prepared_mapping',
]
}
def next(self) -> BaseInvocation | None:
    """Return the next node ready to execute, or None when nothing is ready.

    When no prepared node is ready, prepares as many nodes as possible (calling
    ``_prepare`` until it returns None) and looks again. A returned node has had
    ``_prepare_inputs`` applied to it.
    """
    # TODO: enable multiple nodes to execute simultaneously by tracking currently executing nodes
    # possibly with a timeout?
    ready = self._get_next_node()
    if ready is None:
        # Prepare as many nodes as we can; _prepare() returns None when it has nothing left
        while self._prepare() is not None:
            pass
        ready = self._get_next_node()

    # Get values from edges
    if ready is not None:
        self._prepare_inputs(ready)

    return ready
def complete(self, node_id: str, output: InvocationOutputsUnion):
    """Record ``output`` for prepared node ``node_id`` and mark it executed.

    Ids not present in ``execution_graph`` are ignored. Once every prepared node
    mapped to the same source node has executed, the source node is also marked
    executed and appended to ``executed_history``.
    """
    if node_id not in self.execution_graph.nodes:
        return  # TODO: log error?

    self.executed.add(node_id)
    self.results[node_id] = output

    # Promote the source node once all of its prepared copies are done
    source_node = self.prepared_source_mapping[node_id]
    if all(prepared in self.executed for prepared in self.source_prepared_mapping[source_node]):
        self.executed.add(source_node)
        self.executed_history.append(source_node)
def set_node_error(self, node_id: str, error: str):
    """Store ``error`` as the error message for ``node_id``, replacing any previous one."""
    self.errors.update({node_id: error})
def is_complete(self) -> bool:
    """Return True once any node has errored or every node of the flattened source graph has executed."""
    flat_ids = set(self.graph.nx_graph_flat().nodes)
    if self.has_error():
        return True
    return all(flat_id in self.executed for flat_id in flat_ids)
def has_error(self) -> bool:
    """Return True if at least one node error has been recorded."""
    return bool(self.errors)
def _create_execution_node(
    self, node_path: str, iteration_node_map: list[tuple[str, str]]
) -> list[str]:
    """Instantiate execution copies of a source node and wire up their input edges.

    Args:
        node_path: Path of the node in the source graph to instantiate.
        iteration_node_map: ``(source_node_id, prepared_node_id)`` pairs selecting
            which prepared copy of each upstream source node feeds this node.

    Returns:
        The ids of the newly created execution nodes: one per element of the
        input collection for an iterate node, otherwise exactly one. An iterate
        node whose input collection is empty yields an empty list.
    """
    source_node = self.graph.get_node(node_path)

    # -1 is the sentinel for "not an iterator": a single copy is created below.
    iteration_count = -1
    if isinstance(source_node, IterateInvocation):
        # The edge feeding the iterator's "collection" input; next() raises
        # StopIteration when the iterator has no such input edge.
        collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
        collection_producer_id = next(
            pair[1] for pair in iteration_node_map if pair[0] == collection_edge.source.node_id
        )
        collection_output = self.results[collection_producer_id]
        iteration_count = len(getattr(collection_output, collection_edge.source.field))

    created_ids = []
    if iteration_count == 0:
        # TODO: should this raise a warning? It might just happen if an empty collection is input, and should be valid.
        # NOTE(review): nothing is recorded in source_prepared_mapping for node_path
        # here, so _prepare() can select this node again and complete() can never
        # mark it executed — confirm how empty collections are meant to be handled.
        return created_ids

    # Re-point every incoming edge at the matching prepared copy of its source.
    # Collect nodes may receive several edges into the same destination field.
    # The destination node id is a placeholder, filled in per copy below.
    template_edges = [
        Edge(
            source=EdgeConnection(node_id=prepared_id, field=edge.source.field),
            destination=EdgeConnection(node_id="", field=edge.destination.field),
        )
        for edge in self.graph._get_input_edges(node_path)
        for prepared_id in (pair[1] for pair in iteration_node_map if pair[0] == edge.source.node_id)
    ]

    iteration_indices = range(iteration_count) if iteration_count > 0 else [-1]
    for index in iteration_indices:
        clone = copy.deepcopy(source_node)
        clone.id = str(uuid.uuid4())
        if isinstance(clone, IterateInvocation):
            # Each iterator copy is responsible for a single element.
            clone.index = index

        self.execution_graph.add_node(clone)
        self.prepared_source_mapping[clone.id] = node_path
        self.source_prepared_mapping.setdefault(node_path, set()).add(clone.id)

        for template in template_edges:
            self.execution_graph.add_edge(
                Edge(
                    source=template.source,
                    destination=EdgeConnection(node_id=clone.id, field=template.destination.field),
                )
            )

        created_ids.append(clone.id)

    return created_ids
def _iterator_graph(self) -> nx.DiGraph:
    """Return the flattened source graph with every edge *into* a collect node removed.

    Cutting those edges means an ancestor search from any node only reaches the
    iterators that are still "active" for it (a collect node closes an iteration).
    """
    flat_graph = self.graph.nx_graph_flat()
    for node_path in self.graph.nodes:
        if not isinstance(self.graph.get_node(node_path), CollectInvocation):
            continue
        flat_graph.remove_edges_from(list(flat_graph.in_edges(node_path)))
    return flat_graph
def _get_node_iterators(self, node_id: str) -> list[str]:
    """Return the source ids of all iterate nodes upstream of ``node_id``.

    Ancestors are searched in the collector-pruned graph from ``_iterator_graph``,
    so iterators already closed by a collect node are not included.
    """
    pruned = self._iterator_graph()
    return list(
        filter(
            lambda ancestor: isinstance(self.graph.get_node(ancestor), IterateInvocation),
            nx.ancestors(pruned, node_id),
        )
    )
def _prepare(self) -> Optional[str]:
    """Prepare (instantiate into the execution graph) the next eligible source node.

    The first node, in topological order of the flattened source graph, is chosen
    that:
      - has not already been prepared,
      - is not an iterate node with unexecuted inputs, and
      - has no unexecuted iterate ancestor.

    A collect node gets a single execution node fed by every prepared copy of
    each parent. Any other node gets one execution node per combination of
    prepared upstream iterator copies.

    Returns:
        The id of the first execution node created, or None when no node was
        eligible or preparation created no execution nodes.
    """
    # Get flattened source graph
    g = self.graph.nx_graph_flat()

    sorted_nodes = nx.topological_sort(g)
    next_node_id = next(
        (
            n
            for n in sorted_nodes
            # exclude nodes that have already been prepared
            if n not in self.source_prepared_mapping
            # exclude iterate nodes whose inputs have not been executed
            and not (
                isinstance(self.graph.get_node(n), IterateInvocation)  # `n` is an iterate node...
                and not all(e[0] in self.executed for e in g.in_edges(n))  # ...that has unexecuted inputs
            )
            # exclude nodes who have unexecuted iterate ancestors
            and not any(
                isinstance(self.graph.get_node(a), IterateInvocation)  # `a` is an iterate ancestor of `n`...
                and a not in self.executed  # ...that is not executed
                for a in nx.ancestors(g, n)  # for all ancestors `a` of node `n`
            )
        ),
        None,
    )

    if next_node_id is None:
        return None

    # Get all parents of the next node
    next_node_parents = [e[0] for e in g.in_edges(next_node_id)]

    # Create execution nodes
    next_node = self.graph.get_node(next_node_id)
    new_node_ids = []
    if isinstance(next_node, CollectInvocation):
        # Collapse all iterator input mappings and create a single execution node
        # for the collect invocation. A flat comprehension is used deliberately:
        # the previous nested-generator + chain(*...) form bound the parent id
        # late, tagging every pair with the *last* parent.
        all_iteration_mappings = [
            (parent, prepared)
            for parent in next_node_parents
            for prepared in self.source_prepared_mapping[parent]
        ]
        create_results = self._create_execution_node(next_node_id, all_iteration_mappings)
        if create_results is not None:
            new_node_ids.extend(create_results)
    else:  # Iterators or normal nodes
        # Get all iterator combinations for this node
        # Will produce a list of lists of prepared iterator nodes, from which results can be iterated
        iterator_nodes = self._get_node_iterators(next_node_id)
        iterator_nodes_prepared = [list(self.source_prepared_mapping[n]) for n in iterator_nodes]
        iterator_node_prepared_combinations = list(itertools.product(*iterator_nodes_prepared))

        # Select the correct prepared parents for each iteration
        # For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
        # TODO: Handle a node mapping to none
        eg = self.execution_graph.nx_graph_flat()
        prepared_parent_mappings = [[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents] for it in iterator_node_prepared_combinations]  # type: ignore

        # Create execution node for each iteration
        for iteration_mappings in prepared_parent_mappings:
            create_results = self._create_execution_node(next_node_id, iteration_mappings)  # type: ignore
            if create_results is not None:
                new_node_ids.extend(create_results)

    return next(iter(new_node_ids), None)
2023-03-03 06:02:00 +00:00
def _get_iteration_node(
    self,
    source_node_path: str,
    graph: nx.DiGraph,
    execution_graph: nx.DiGraph,
    prepared_iterator_nodes: list[str],
) -> Optional[str]:
    """Pick the prepared copy of ``source_node_path`` belonging to one iteration.

    ``prepared_iterator_nodes`` holds one prepared copy per upstream iterator
    (``_prepare`` passes the tuples produced by ``itertools.product``).

    Resolution order:
      1. If the source node was prepared exactly once, that copy is returned.
      2. If one of its copies is itself in ``prepared_iterator_nodes``, that copy
         is returned.
      3. Otherwise, only the iterators whose source node reaches
         ``source_node_path`` in ``graph`` are considered, and the first copy
         reachable in ``execution_graph`` from all of them is returned.

    Returns None when no copy satisfies step 3.
    """
    candidates = self.source_prepared_mapping[source_node_path]
    if len(candidates) == 1:
        return next(iter(candidates))

    for candidate in candidates:
        if candidate in prepared_iterator_nodes:
            return candidate

    # (prepared iterator id, source iterator id) for every selected iterator.
    iterator_pairs = [(prepared, self.prepared_source_mapping[prepared]) for prepared in prepared_iterator_nodes]
    # Must stay a list: it is re-iterated once per candidate below.
    upstream_iterators = [
        prepared for prepared, source in iterator_pairs if nx.has_path(graph, source, source_node_path)
    ]

    for candidate in candidates:
        if all(nx.has_path(execution_graph, prepared, candidate) for prepared in upstream_iterators):
            return candidate
    return None
def _get_next_node(self) -> Optional[BaseInvocation]:
    """Return the next execution node that is ready to run, or None.

    A node is ready when it has not been executed and every node feeding it has
    been executed. Candidates are visited in depth-first pre-order of the
    execution graph, so the first ready node along the traversal is chosen.
    """
    exec_graph = self.execution_graph.nx_graph()

    # NOTE: DFS pre-order is not a topological order by itself. The readiness
    # check below is what guarantees that a node's inputs have all been produced.
    # The traversal order only decides which ready node wins. The intent (per
    # the depth-first execution change) is to continue down a branch before
    # starting the next one.
    for node_id in nx.dfs_preorder_nodes(exec_graph):
        if node_id in self.executed:
            continue
        if all(edge[0] in self.executed for edge in exec_graph.in_edges(node_id)):
            return self.execution_graph.nodes[node_id]
    return None
def _prepare_inputs(self, node: BaseInvocation):
    """Copy upstream results into ``node``'s input fields before it runs.

    A collect node has the values of all edges targeting its ``item`` field
    gathered into its ``collection`` field. Any other node has each incoming
    edge's source output assigned to the matching destination field.
    """
    incoming = [edge for edge in self.execution_graph.edges if edge.destination.node_id == node.id]

    def _source_value(edge):
        # The output value the edge's source node produced for the edge's field.
        return getattr(self.results[edge.source.node_id], edge.source.field)

    if not isinstance(node, CollectInvocation):
        for edge in incoming:
            setattr(node, edge.destination.field, _source_value(edge))
        return

    gathered = [_source_value(edge) for edge in incoming if edge.destination.field == "item"]
    setattr(node, "collection", gathered)
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
2023-03-15 06:09:30 +00:00
def _is_edge_valid(self, edge: Edge) -> bool:
    """Return whether ``edge`` could still be added given the current execution state.

    The edge must pass the source graph's own validation, and its destination
    node must not have been prepared (or executed) yet.
    """
    try:
        self.graph._validate_edge(edge)
    except InvalidEdgeError:
        return False
    return edge.destination.node_id not in self.source_prepared_mapping
def _is_node_updatable(self, node_id: str) -> bool:
    """Return True while ``node_id`` has not been prepared (and therefore not executed)."""
    already_prepared = node_id in self.source_prepared_mapping
    return not already_prepared
def add_node(self, node: BaseInvocation) -> None:
    """Add ``node`` to the underlying source graph.

    Unlike the update/delete operations, no preparation-state check is done here.
    """
    source_graph = self.graph
    source_graph.add_node(node)
2023-03-03 06:02:00 +00:00
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
    """Replace the source-graph node at ``node_path`` with ``new_node``.

    Raises:
        NodeAlreadyExecutedError: if the node has already been prepared or executed.
    """
    if self._is_node_updatable(node_path):
        self.graph.update_node(node_path, new_node)
        return
    raise NodeAlreadyExecutedError(
        f"Node {node_path} has already been prepared or executed and cannot be updated"
    )
def delete_node(self, node_path: str) -> None:
    """Remove the node at ``node_path`` from the source graph.

    Raises:
        NodeAlreadyExecutedError: if the node has already been prepared or executed.
    """
    if self._is_node_updatable(node_path):
        self.graph.delete_node(node_path)
        return
    raise NodeAlreadyExecutedError(
        f"Node {node_path} has already been prepared or executed and cannot be deleted"
    )
2023-03-15 06:09:30 +00:00
def add_edge(self, edge: Edge) -> None:
    """Add ``edge`` to the source graph.

    Raises:
        NodeAlreadyExecutedError: if the edge's destination node has already
            been prepared or executed.
    """
    target_id = edge.destination.node_id
    if self._is_node_updatable(target_id):
        self.graph.add_edge(edge)
        return
    raise NodeAlreadyExecutedError(
        f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
    )
2023-03-03 06:02:00 +00:00
2023-03-15 06:09:30 +00:00
def delete_edge(self, edge: Edge) -> None:
    """Remove ``edge`` from the source graph.

    Raises:
        NodeAlreadyExecutedError: if the edge's destination node has already
            been prepared or executed.
    """
    target_id = edge.destination.node_id
    if self._is_node_updatable(target_id):
        self.graph.delete_edge(edge)
        return
    raise NodeAlreadyExecutedError(
        f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
    )
2023-03-03 06:02:00 +00:00
class ExposedNodeInput(BaseModel):
    # One input of a node inside a LibraryGraph, published under `alias`.
    # LibraryGraph validates that node_path exists and that `field` is an input
    # of that node, and that aliases are unique within exposed_inputs.
    node_path: str = Field(description="The node path to the node with the input")
    field: str = Field(description="The field name of the input")
    alias: str = Field(description="The alias of the input")
class ExposedNodeOutput(BaseModel):
    # One output of a node inside a LibraryGraph, published under `alias`.
    # LibraryGraph validates that node_path exists and that `field` is an output
    # of that node, and that aliases are unique within exposed_outputs.
    node_path: str = Field(description="The node path to the node with the output")
    field: str = Field(description="The field name of the output")
    alias: str = Field(description="The alias of the output")
class LibraryGraph(BaseModel):
    # A named, reusable graph together with the node inputs/outputs it exposes
    # under aliases. Exposed entries are checked against the graph on validation.
    id: str = Field(
        description="The unique identifier for this library graph",
        # pydantic (v1) does not validate/coerce default_factory results, so the
        # factory must produce a str itself rather than a uuid.UUID instance.
        default_factory=lambda: str(uuid.uuid4()),
    )
    graph: Graph = Field(description="The graph")
    name: str = Field(description="The name of the graph")
    description: str = Field(description="The description of the graph")
    exposed_inputs: list[ExposedNodeInput] = Field(description="The inputs exposed by this graph", default_factory=list)
    exposed_outputs: list[ExposedNodeOutput] = Field(description="The outputs exposed by this graph", default_factory=list)

    @validator('exposed_inputs', 'exposed_outputs')
    def validate_exposed_aliases(cls, v):
        """Reject a list of exposed entries in which two entries share an alias."""
        if len(v) != len({i.alias for i in v}):
            raise ValueError("Duplicate exposed alias")
        return v

    @root_validator
    def validate_exposed_nodes(cls, values):
        """Check that every exposed input/output names an existing node and field.

        A field that failed its own validation is absent from ``values``. In that
        case the related cross-field checks are skipped, so the original field
        error is reported instead of a bare KeyError escaping model construction.
        """
        graph = values.get('graph')
        if graph is None:
            return values

        # Validate exposed inputs
        for exposed_input in values.get('exposed_inputs', []):
            if not graph.has_node(exposed_input.node_path):
                raise ValueError(f"Exposed input node {exposed_input.node_path} does not exist")
            node = graph.get_node(exposed_input.node_path)
            if get_input_field(node, exposed_input.field) is None:
                raise ValueError(f"Exposed input field {exposed_input.field} does not exist on node {exposed_input.node_path}")

        # Validate exposed outputs
        for exposed_output in values.get('exposed_outputs', []):
            if not graph.has_node(exposed_output.node_path):
                raise ValueError(f"Exposed output node {exposed_output.node_path} does not exist")
            node = graph.get_node(exposed_output.node_path)
            if get_output_field(node, exposed_output.field) is None:
                raise ValueError(f"Exposed output field {exposed_output.field} does not exist on node {exposed_output.node_path}")

        return values
# Resolve string forward references in GraphInvocation's field annotations now
# that the classes in this module are defined. NOTE(review): presumably
# GraphInvocation (declared earlier, outside this view) annotates a field with a
# class defined later such as Graph — confirm.
GraphInvocation.update_forward_refs()