mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(nodes): remove GraphInvocation
`GraphInvocation` is a node that can contain a whole graph. It is removed for a number of reasons: 1. This feature was unused (the UI doesn't support it) and there is no plan for it to be used. The use-case it served is known in other node execution engines as "node groups" or "blocks" - a self-contained group of nodes, which has group inputs and outputs. This is a planned feature that will be handled client-side. 2. It adds substantial complexity to the graph processing logic. It's probably not enough to have a measurable performance impact but it does make it harder to work in the graph logic. 3. It allows for graphs to be recursive, and the improved invocations union handling does not play well with it. Actually, it works fine within `graph.py` but not in the tests for some reason. I do not understand why. There's probably a workaround, but I took this as encouragement to remove `GraphInvocation` from the app since we don't use it.
This commit is contained in:
parent
d7adab89bd
commit
ab83fb2cea
@ -184,10 +184,6 @@ class NodeIdMismatchError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidSubGraphError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class CyclicalGraphError(ValueError):
|
||||
pass
|
||||
|
||||
@ -196,25 +192,6 @@ class UnknownGraphValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Create and use an Empty output?
|
||||
@invocation_output("graph_output")
|
||||
class GraphInvocationOutput(BaseInvocationOutput):
|
||||
pass
|
||||
|
||||
|
||||
# TODO: Fill this out and move to invocations
|
||||
@invocation("graph", version="1.0.0")
|
||||
class GraphInvocation(BaseInvocation):
|
||||
"""Execute a graph"""
|
||||
|
||||
# TODO: figure out how to create a default here
|
||||
graph: "Graph" = InputField(description="The graph to run", default=None)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> GraphInvocationOutput:
|
||||
"""Invoke with provided services and return outputs."""
|
||||
return GraphInvocationOutput()
|
||||
|
||||
|
||||
@invocation_output("iterate_output")
|
||||
class IterateInvocationOutput(BaseInvocationOutput):
|
||||
"""Used to connect iteration outputs. Will be expanded to a specific output."""
|
||||
@ -346,41 +323,21 @@ class Graph(BaseModel):
|
||||
|
||||
self.nodes[node.id] = node
|
||||
|
||||
def _get_graph_and_node(self, node_path: str) -> tuple["Graph", str]:
|
||||
"""Returns the graph and node id for a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
if node_path in self.nodes:
|
||||
return (self, node_path)
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
if node_id not in self.nodes:
|
||||
raise NodeNotFoundError(f"Node {node_path} not found in graph")
|
||||
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if not isinstance(node, GraphInvocation):
|
||||
# There's more node path left but this isn't a graph - failure
|
||||
raise NodeNotFoundError("Node path terminated early at a non-graph node")
|
||||
|
||||
return node.graph._get_graph_and_node(node_path[node_path.index(".") + 1 :])
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
"""Deletes a node from a graph"""
|
||||
|
||||
try:
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
|
||||
# Delete edges for this node
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
|
||||
for edge_graph, _, edge in input_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
for edge in input_edges:
|
||||
self.delete_edge(edge)
|
||||
|
||||
for edge_graph, _, edge in output_edges:
|
||||
edge_graph.delete_edge(edge)
|
||||
for edge in output_edges:
|
||||
self.delete_edge(edge)
|
||||
|
||||
del graph.nodes[node_id]
|
||||
del self.nodes[node_id]
|
||||
|
||||
except NodeNotFoundError:
|
||||
pass # Ignore, not doesn't exist (should this throw?)
|
||||
@ -430,13 +387,6 @@ class Graph(BaseModel):
|
||||
if k != v.id:
|
||||
raise NodeIdMismatchError(f"Node ids must match, got {k} and {v.id}")
|
||||
|
||||
# Validate all subgraphs
|
||||
for gn in (n for n in self.nodes.values() if isinstance(n, GraphInvocation)):
|
||||
try:
|
||||
gn.graph.validate_self()
|
||||
except Exception as e:
|
||||
raise InvalidSubGraphError(f"Subgraph {gn.id} is invalid") from e
|
||||
|
||||
# Validate that all edges match nodes and fields in the graph
|
||||
for edge in self.edges:
|
||||
source_node = self.nodes.get(edge.source.node_id, None)
|
||||
@ -498,7 +448,6 @@ class Graph(BaseModel):
|
||||
except (
|
||||
DuplicateNodeIdError,
|
||||
NodeIdMismatchError,
|
||||
InvalidSubGraphError,
|
||||
NodeNotFoundError,
|
||||
NodeFieldNotFoundError,
|
||||
CyclicalGraphError,
|
||||
@ -519,7 +468,7 @@ class Graph(BaseModel):
|
||||
def _validate_edge(self, edge: Edge):
|
||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||
|
||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||
# Validate that the nodes exist
|
||||
try:
|
||||
from_node = self.get_node(edge.source.node_id)
|
||||
to_node = self.get_node(edge.destination.node_id)
|
||||
@ -586,171 +535,90 @@ class Graph(BaseModel):
|
||||
f"Collector input type does not match collector output type: {edge.source.node_id}.{edge.source.field} to {edge.destination.node_id}.{edge.destination.field}"
|
||||
)
|
||||
|
||||
def has_node(self, node_path: str) -> bool:
|
||||
def has_node(self, node_id: str) -> bool:
|
||||
"""Determines whether or not a node exists in the graph."""
|
||||
try:
|
||||
n = self.get_node(node_path)
|
||||
if n is not None:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
_ = self.get_node(node_id)
|
||||
return True
|
||||
except NodeNotFoundError:
|
||||
return False
|
||||
|
||||
def get_node(self, node_path: str) -> BaseInvocation:
|
||||
"""Gets a node from the graph using a node path."""
|
||||
# Materialized graphs may have nodes at the top level
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
return graph.nodes[node_id]
|
||||
def get_node(self, node_id: str) -> BaseInvocation:
|
||||
"""Gets a node from the graph."""
|
||||
try:
|
||||
return self.nodes[node_id]
|
||||
except KeyError as e:
|
||||
raise NodeNotFoundError(f"Node {node_id} not found in graph") from e
|
||||
|
||||
def _get_node_path(self, node_id: str, prefix: Optional[str] = None) -> str:
|
||||
return node_id if prefix is None or prefix == "" else f"{prefix}.{node_id}"
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
"""Updates a node in the graph."""
|
||||
graph, node_id = self._get_graph_and_node(node_path)
|
||||
node = graph.nodes[node_id]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
# Ensure the node type matches the new node
|
||||
if type(node) is not type(new_node):
|
||||
raise TypeError(f"Node {node_path} is type {type(node)} but new node is type {type(new_node)}")
|
||||
raise TypeError(f"Node {node_id} is type {type(node)} but new node is type {type(new_node)}")
|
||||
|
||||
# Ensure the new id is either the same or is not in the graph
|
||||
prefix = None if "." not in node_path else node_path[: node_path.rindex(".")]
|
||||
new_path = self._get_node_path(new_node.id, prefix=prefix)
|
||||
if new_node.id != node.id and self.has_node(new_path):
|
||||
raise NodeAlreadyInGraphError("Node with id {new_node.id} already exists in graph")
|
||||
if new_node.id != node.id and self.has_node(new_node.id):
|
||||
raise NodeAlreadyInGraphError(f"Node with id {new_node.id} already exists in graph")
|
||||
|
||||
# Set the new node in the graph
|
||||
graph.nodes[new_node.id] = new_node
|
||||
self.nodes[new_node.id] = new_node
|
||||
if new_node.id != node.id:
|
||||
input_edges = self._get_input_edges_and_graphs(node_path)
|
||||
output_edges = self._get_output_edges_and_graphs(node_path)
|
||||
input_edges = self._get_input_edges(node_id)
|
||||
output_edges = self._get_output_edges(node_id)
|
||||
|
||||
# Delete node and all edges
|
||||
graph.delete_node(node_path)
|
||||
self.delete_node(node_id)
|
||||
|
||||
# Create new edges for each input and output
|
||||
for graph, _, edge in input_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.destination.node_id
|
||||
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
for edge in input_edges:
|
||||
self.add_edge(
|
||||
Edge(
|
||||
source=edge.source,
|
||||
destination=EdgeConnection(node_id=new_graph_node_path, field=edge.destination.field),
|
||||
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||
)
|
||||
)
|
||||
|
||||
for graph, _, edge in output_edges:
|
||||
# Remove the graph prefix from the node path
|
||||
new_graph_node_path = (
|
||||
new_node.id
|
||||
if "." not in edge.source.node_id
|
||||
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||
)
|
||||
graph.add_edge(
|
||||
for edge in output_edges:
|
||||
self.add_edge(
|
||||
Edge(
|
||||
source=EdgeConnection(node_id=new_graph_node_path, field=edge.source.field),
|
||||
source=EdgeConnection(node_id=new_node.id, field=edge.source.field),
|
||||
destination=edge.destination,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_input_edges(self, node_path: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node"""
|
||||
edges = self._get_input_edges_and_graphs(node_path)
|
||||
def _get_input_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all input edges for a node. If field is provided, only edges to that field are returned."""
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||
edges = [e for e in self.edges if e.destination.node_id == node_id]
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
if field is None:
|
||||
return edges
|
||||
|
||||
def _get_input_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
filtered_edges = [e for e in edges if e.destination.field == field]
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.destination.node_id == node_path])
|
||||
return filtered_edges
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
def _get_output_edges(self, node_id: str, field: Optional[str] = None) -> list[Edge]:
|
||||
"""Gets all output edges for a node. If field is provided, only edges from that field are returned."""
|
||||
edges = [e for e in self.edges if e.source.node_id == node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_input_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
if field is None:
|
||||
return edges
|
||||
|
||||
return edges
|
||||
filtered_edges = [e for e in edges if e.source.field == field]
|
||||
|
||||
def _get_output_edges(self, node_path: str, field: str) -> list[Edge]:
|
||||
"""Gets all output edges for a node"""
|
||||
edges = self._get_output_edges_and_graphs(node_path)
|
||||
|
||||
# Filter to edges that match the field
|
||||
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||
|
||||
# Create full node paths for each edge
|
||||
return [
|
||||
Edge(
|
||||
source=EdgeConnection(
|
||||
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||
field=e.source.field,
|
||||
),
|
||||
destination=EdgeConnection(
|
||||
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||
field=e.destination.field,
|
||||
),
|
||||
)
|
||||
for _, prefix, e in filtered_edges
|
||||
]
|
||||
|
||||
def _get_output_edges_and_graphs(
|
||||
self, node_path: str, prefix: Optional[str] = None
|
||||
) -> list[tuple["Graph", Union[str, None], Edge]]:
|
||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||
edges = []
|
||||
|
||||
# Return any input edges that appear in this graph
|
||||
edges.extend([(self, prefix, e) for e in self.edges if e.source.node_id == node_path])
|
||||
|
||||
node_id = node_path if "." not in node_path else node_path[: node_path.index(".")]
|
||||
node = self.nodes[node_id]
|
||||
|
||||
if isinstance(node, GraphInvocation):
|
||||
graph = node.graph
|
||||
graph_path = node.id if prefix is None or prefix == "" else self._get_node_path(node.id, prefix=prefix)
|
||||
graph_edges = graph._get_output_edges_and_graphs(node_path[(len(node_id) + 1) :], prefix=graph_path)
|
||||
edges.extend(graph_edges)
|
||||
|
||||
return edges
|
||||
return filtered_edges
|
||||
|
||||
def _is_iterator_connection_valid(
|
||||
self,
|
||||
node_path: str,
|
||||
node_id: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "item")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "collection")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "item")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -778,12 +646,12 @@ class Graph(BaseModel):
|
||||
|
||||
def _is_collector_connection_valid(
|
||||
self,
|
||||
node_path: str,
|
||||
node_id: str,
|
||||
new_input: Optional[EdgeConnection] = None,
|
||||
new_output: Optional[EdgeConnection] = None,
|
||||
) -> bool:
|
||||
inputs = [e.source for e in self._get_input_edges(node_path, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_path, "collection")]
|
||||
inputs = [e.source for e in self._get_input_edges(node_id, "item")]
|
||||
outputs = [e.destination for e in self._get_output_edges(node_id, "collection")]
|
||||
|
||||
if new_input is not None:
|
||||
inputs.append(new_input)
|
||||
@ -839,27 +707,17 @@ class Graph(BaseModel):
|
||||
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
|
||||
return g
|
||||
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None, prefix: Optional[str] = None) -> nx.DiGraph:
|
||||
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
|
||||
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
|
||||
g = nx_graph or nx.DiGraph()
|
||||
|
||||
# Add all nodes from this graph except graph/iteration nodes
|
||||
g.add_nodes_from(
|
||||
[
|
||||
self._get_node_path(n.id, prefix)
|
||||
for n in self.nodes.values()
|
||||
if not isinstance(n, GraphInvocation) and not isinstance(n, IterateInvocation)
|
||||
]
|
||||
)
|
||||
|
||||
# Expand graph nodes
|
||||
for sgn in (gn for gn in self.nodes.values() if isinstance(gn, GraphInvocation)):
|
||||
g = sgn.graph.nx_graph_flat(g, self._get_node_path(sgn.id, prefix))
|
||||
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
|
||||
|
||||
# TODO: figure out if iteration nodes need to be expanded
|
||||
|
||||
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
|
||||
g.add_edges_from([(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix)) for e in unique_edges])
|
||||
g.add_edges_from([(e[0], e[1]) for e in unique_edges])
|
||||
return g
|
||||
|
||||
|
||||
@ -1017,17 +875,17 @@ class GraphExecutionState(BaseModel):
|
||||
"""Returns true if the graph has any errors"""
|
||||
return len(self.errors) > 0
|
||||
|
||||
def _create_execution_node(self, node_path: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
def _create_execution_node(self, node_id: str, iteration_node_map: list[tuple[str, str]]) -> list[str]:
|
||||
"""Prepares an iteration node and connects all edges, returning the new node id"""
|
||||
|
||||
node = self.graph.get_node(node_path)
|
||||
node = self.graph.get_node(node_id)
|
||||
|
||||
self_iteration_count = -1
|
||||
|
||||
# If this is an iterator node, we must create a copy for each iteration
|
||||
if isinstance(node, IterateInvocation):
|
||||
# Get input collection edge (should error if there are no inputs)
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_path, "collection")))
|
||||
input_collection_edge = next(iter(self.graph._get_input_edges(node_id, "collection")))
|
||||
input_collection_prepared_node_id = next(
|
||||
n[1] for n in iteration_node_map if n[0] == input_collection_edge.source.node_id
|
||||
)
|
||||
@ -1041,7 +899,7 @@ class GraphExecutionState(BaseModel):
|
||||
return new_nodes
|
||||
|
||||
# Get all input edges
|
||||
input_edges = self.graph._get_input_edges(node_path)
|
||||
input_edges = self.graph._get_input_edges(node_id)
|
||||
|
||||
# Create new edges for this iteration
|
||||
# For collect nodes, this may contain multiple inputs to the same field
|
||||
@ -1068,10 +926,10 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Add to execution graph
|
||||
self.execution_graph.add_node(new_node)
|
||||
self.prepared_source_mapping[new_node.id] = node_path
|
||||
if node_path not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_path] = set()
|
||||
self.source_prepared_mapping[node_path].add(new_node.id)
|
||||
self.prepared_source_mapping[new_node.id] = node_id
|
||||
if node_id not in self.source_prepared_mapping:
|
||||
self.source_prepared_mapping[node_id] = set()
|
||||
self.source_prepared_mapping[node_id].add(new_node.id)
|
||||
|
||||
# Add new edges to execution graph
|
||||
for edge in new_edges:
|
||||
@ -1175,13 +1033,13 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
def _get_iteration_node(
|
||||
self,
|
||||
source_node_path: str,
|
||||
source_node_id: str,
|
||||
graph: nx.DiGraph,
|
||||
execution_graph: nx.DiGraph,
|
||||
prepared_iterator_nodes: list[str],
|
||||
) -> Optional[str]:
|
||||
"""Gets the prepared version of the specified source node that matches every iteration specified"""
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_path]
|
||||
prepared_nodes = self.source_prepared_mapping[source_node_id]
|
||||
if len(prepared_nodes) == 1:
|
||||
return next(iter(prepared_nodes))
|
||||
|
||||
@ -1192,7 +1050,7 @@ class GraphExecutionState(BaseModel):
|
||||
|
||||
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
|
||||
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_path)]
|
||||
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
|
||||
|
||||
return next(
|
||||
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
|
||||
@ -1261,19 +1119,19 @@ class GraphExecutionState(BaseModel):
|
||||
def add_node(self, node: BaseInvocation) -> None:
|
||||
self.graph.add_node(node)
|
||||
|
||||
def update_node(self, node_path: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
def update_node(self, node_id: str, new_node: BaseInvocation) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_path} has already been prepared or executed and cannot be updated"
|
||||
f"Node {node_id} has already been prepared or executed and cannot be updated"
|
||||
)
|
||||
self.graph.update_node(node_path, new_node)
|
||||
self.graph.update_node(node_id, new_node)
|
||||
|
||||
def delete_node(self, node_path: str) -> None:
|
||||
if not self._is_node_updatable(node_path):
|
||||
def delete_node(self, node_id: str) -> None:
|
||||
if not self._is_node_updatable(node_id):
|
||||
raise NodeAlreadyExecutedError(
|
||||
f"Node {node_path} has already been prepared or executed and cannot be deleted"
|
||||
f"Node {node_id} has already been prepared or executed and cannot be deleted"
|
||||
)
|
||||
self.graph.delete_node(node_path)
|
||||
self.graph.delete_node(node_id)
|
||||
|
||||
def add_edge(self, edge: Edge) -> None:
|
||||
if not self._is_node_updatable(edge.destination.node_id):
|
||||
|
@ -23,7 +23,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.invocation_stats.invocation_stats_default import InvocationStatsService
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import DEFAULT_QUEUE_ID
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -35,17 +35,6 @@ def simple_graph():
|
||||
return g
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph_with_subgraph():
|
||||
sub_g = Graph()
|
||||
sub_g.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
sub_g.add_node(TextToImageTestInvocation(id="2"))
|
||||
sub_g.add_edge(create_edge("1", "prompt", "2", "prompt"))
|
||||
g = Graph()
|
||||
g.add_node(GraphInvocation(id="1", graph=sub_g))
|
||||
return g
|
||||
|
||||
|
||||
# This must be defined here to avoid issues with the dynamic creation of the union of all invocation types
|
||||
# Defining it in a separate module will cause the union to be incomplete, and pydantic will not validate
|
||||
# the test invocations.
|
||||
|
@ -8,8 +8,6 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.image import ShowImageInvocation
|
||||
from invokeai.app.invocations.math import AddInvocation, SubtractInvocation
|
||||
from invokeai.app.invocations.primitives import (
|
||||
FloatCollectionInvocation,
|
||||
FloatInvocation,
|
||||
@ -17,13 +15,11 @@ from invokeai.app.invocations.primitives import (
|
||||
StringInvocation,
|
||||
)
|
||||
from invokeai.app.invocations.upscale import ESRGANInvocation
|
||||
from invokeai.app.services.shared.default_graphs import create_text_to_image
|
||||
from invokeai.app.services.shared.graph import (
|
||||
CollectInvocation,
|
||||
Edge,
|
||||
EdgeConnection,
|
||||
Graph,
|
||||
GraphInvocation,
|
||||
InvalidEdgeError,
|
||||
IterateInvocation,
|
||||
NodeAlreadyInGraphError,
|
||||
@ -425,19 +421,19 @@ def test_graph_invalid_if_edges_reference_missing_nodes():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_subgraph_invalid():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_invalid_if_subgraph_invalid():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
|
||||
n1.graph.nodes[n1_1.id] = n1_1
|
||||
e1 = create_edge("1", "image", "2", "image")
|
||||
n1.graph.edges.append(e1)
|
||||
# n1_1 = TextToImageTestInvocation(id="2", prompt="Banana sushi")
|
||||
# n1.graph.nodes[n1_1.id] = n1_1
|
||||
# e1 = create_edge("1", "image", "2", "image")
|
||||
# n1.graph.edges.append(e1)
|
||||
|
||||
g.nodes[n1.id] = n1
|
||||
# g.nodes[n1.id] = n1
|
||||
|
||||
assert g.is_valid() is False
|
||||
# assert g.is_valid() is False
|
||||
|
||||
|
||||
def test_graph_invalid_if_has_cycle():
|
||||
@ -466,108 +462,108 @@ def test_graph_invalid_with_invalid_connection():
|
||||
assert g.is_valid() is False
|
||||
|
||||
|
||||
# TODO: Subgraph operations
|
||||
def test_graph_gets_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# # TODO: Subgraph operations
|
||||
# def test_graph_gets_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
result = g.get_node("1.1")
|
||||
# result = g.get_node("1.1")
|
||||
|
||||
assert result is not None
|
||||
assert result.id == "1"
|
||||
assert result == n1_1
|
||||
# assert result is not None
|
||||
# assert result.id == "1"
|
||||
# assert result == n1_1
|
||||
|
||||
|
||||
def test_graph_expands_subgraph():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_expands_subgraph():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = AddInvocation(id="1", a=1, b=2)
|
||||
n1_2 = SubtractInvocation(id="2", b=3)
|
||||
n1.graph.add_node(n1_1)
|
||||
n1.graph.add_node(n1_2)
|
||||
n1.graph.add_edge(create_edge("1", "value", "2", "a"))
|
||||
# n1_1 = AddInvocation(id="1", a=1, b=2)
|
||||
# n1_2 = SubtractInvocation(id="2", b=3)
|
||||
# n1.graph.add_node(n1_1)
|
||||
# n1.graph.add_node(n1_2)
|
||||
# n1.graph.add_edge(create_edge("1", "value", "2", "a"))
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = AddInvocation(id="2", b=5)
|
||||
g.add_node(n2)
|
||||
g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||
# n2 = AddInvocation(id="2", b=5)
|
||||
# g.add_node(n2)
|
||||
# g.add_edge(create_edge("1.2", "value", "2", "a"))
|
||||
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||
assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||
# dg = g.nx_graph_flat()
|
||||
# assert set(dg.nodes) == {"1.1", "1.2", "2"}
|
||||
# assert set(dg.edges) == {("1.1", "1.2"), ("1.2", "2")}
|
||||
|
||||
|
||||
def test_graph_subgraph_t2i():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
# def test_graph_subgraph_t2i():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
|
||||
# Get text to image default graph
|
||||
lg = create_text_to_image()
|
||||
n1.graph = lg.graph
|
||||
# # Get text to image default graph
|
||||
# lg = create_text_to_image()
|
||||
# n1.graph = lg.graph
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = IntegerInvocation(id="2", value=512)
|
||||
n3 = IntegerInvocation(id="3", value=256)
|
||||
# n2 = IntegerInvocation(id="2", value=512)
|
||||
# n3 = IntegerInvocation(id="3", value=256)
|
||||
|
||||
g.add_node(n2)
|
||||
g.add_node(n3)
|
||||
# g.add_node(n2)
|
||||
# g.add_node(n3)
|
||||
|
||||
g.add_edge(create_edge("2", "value", "1.width", "value"))
|
||||
g.add_edge(create_edge("3", "value", "1.height", "value"))
|
||||
# g.add_edge(create_edge("2", "value", "1.width", "value"))
|
||||
# g.add_edge(create_edge("3", "value", "1.height", "value"))
|
||||
|
||||
n4 = ShowImageInvocation(id="4")
|
||||
g.add_node(n4)
|
||||
g.add_edge(create_edge("1.8", "image", "4", "image"))
|
||||
# n4 = ShowImageInvocation(id="4")
|
||||
# g.add_node(n4)
|
||||
# g.add_edge(create_edge("1.8", "image", "4", "image"))
|
||||
|
||||
# Validate
|
||||
dg = g.nx_graph_flat()
|
||||
assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||
expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||
expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||
print(expected_edges)
|
||||
print(list(dg.edges))
|
||||
assert set(dg.edges) == set(expected_edges)
|
||||
# # Validate
|
||||
# dg = g.nx_graph_flat()
|
||||
# assert set(dg.nodes) == {"1.width", "1.height", "1.seed", "1.3", "1.4", "1.5", "1.6", "1.7", "1.8", "2", "3", "4"}
|
||||
# expected_edges = [(f"1.{e.source.node_id}", f"1.{e.destination.node_id}") for e in lg.graph.edges]
|
||||
# expected_edges.extend([("2", "1.width"), ("3", "1.height"), ("1.8", "4")])
|
||||
# print(expected_edges)
|
||||
# print(list(dg.edges))
|
||||
# assert set(dg.edges) == set(expected_edges)
|
||||
|
||||
|
||||
def test_graph_fails_to_get_missing_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_fails_to_get_missing_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("1.2")
|
||||
# with pytest.raises(NodeNotFoundError):
|
||||
# _ = g.get_node("1.2")
|
||||
|
||||
|
||||
def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
g = Graph()
|
||||
n1 = GraphInvocation(id="1")
|
||||
n1.graph = Graph()
|
||||
# def test_graph_fails_to_enumerate_non_subgraph_node():
|
||||
# g = Graph()
|
||||
# n1 = GraphInvocation(id="1")
|
||||
# n1.graph = Graph()
|
||||
|
||||
n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
n1.graph.add_node(n1_1)
|
||||
# n1_1 = TextToImageTestInvocation(id="1", prompt="Banana sushi")
|
||||
# n1.graph.add_node(n1_1)
|
||||
|
||||
g.add_node(n1)
|
||||
# g.add_node(n1)
|
||||
|
||||
n2 = ESRGANInvocation(id="2")
|
||||
g.add_node(n2)
|
||||
# n2 = ESRGANInvocation(id="2")
|
||||
# g.add_node(n2)
|
||||
|
||||
with pytest.raises(NodeNotFoundError):
|
||||
_ = g.get_node("2.1")
|
||||
# with pytest.raises(NodeNotFoundError):
|
||||
# _ = g.get_node("2.1")
|
||||
|
||||
|
||||
def test_graph_gets_networkx_graph():
|
||||
|
@ -8,10 +8,9 @@ from invokeai.app.services.session_queue.session_queue_common import (
|
||||
NodeFieldValue,
|
||||
calc_session_count,
|
||||
create_session_nfv_tuples,
|
||||
populate_graph,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, GraphInvocation
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from tests.aa_nodes.test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
@ -39,28 +38,28 @@ def batch_graph() -> Graph:
|
||||
return g
|
||||
|
||||
|
||||
def test_populate_graph_with_subgraph():
|
||||
g1 = Graph()
|
||||
g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
|
||||
n1 = PromptTestInvocation(id="1", prompt="Banana snake")
|
||||
subgraph = Graph()
|
||||
subgraph.add_node(n1)
|
||||
g1.add_node(GraphInvocation(id="3", graph=subgraph))
|
||||
# def test_populate_graph_with_subgraph():
|
||||
# g1 = Graph()
|
||||
# g1.add_node(PromptTestInvocation(id="1", prompt="Banana sushi"))
|
||||
# g1.add_node(PromptTestInvocation(id="2", prompt="Banana sushi"))
|
||||
# n1 = PromptTestInvocation(id="1", prompt="Banana snake")
|
||||
# subgraph = Graph()
|
||||
# subgraph.add_node(n1)
|
||||
# g1.add_node(GraphInvocation(id="3", graph=subgraph))
|
||||
|
||||
nfvs = [
|
||||
NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
|
||||
NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
|
||||
NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
|
||||
]
|
||||
# nfvs = [
|
||||
# NodeFieldValue(node_path="1", field_name="prompt", value="Strawberry sushi"),
|
||||
# NodeFieldValue(node_path="2", field_name="prompt", value="Strawberry sunday"),
|
||||
# NodeFieldValue(node_path="3.1", field_name="prompt", value="Strawberry snake"),
|
||||
# ]
|
||||
|
||||
g2 = populate_graph(g1, nfvs)
|
||||
# g2 = populate_graph(g1, nfvs)
|
||||
|
||||
# do not mutate g1
|
||||
assert g1 is not g2
|
||||
assert g2.get_node("1").prompt == "Strawberry sushi"
|
||||
assert g2.get_node("2").prompt == "Strawberry sunday"
|
||||
assert g2.get_node("3.1").prompt == "Strawberry snake"
|
||||
# # do not mutate g1
|
||||
# assert g1 is not g2
|
||||
# assert g2.get_node("1").prompt == "Strawberry sushi"
|
||||
# assert g2.get_node("2").prompt == "Strawberry sunday"
|
||||
# assert g2.get_node("3.1").prompt == "Strawberry snake"
|
||||
|
||||
|
||||
def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
|
||||
|
Loading…
Reference in New Issue
Block a user