mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[nodes] Add Edge data type (#2958)
Adds an `Edge` data type, replacing the current tuple used for edges.
This commit is contained in:
commit
9738b0ff69
@ -10,6 +10,7 @@ from pydantic.fields import Field
|
|||||||
from ...invocations import *
|
from ...invocations import *
|
||||||
from ...invocations.baseinvocation import BaseInvocation
|
from ...invocations.baseinvocation import BaseInvocation
|
||||||
from ...services.graph import (
|
from ...services.graph import (
|
||||||
|
Edge,
|
||||||
EdgeConnection,
|
EdgeConnection,
|
||||||
Graph,
|
Graph,
|
||||||
GraphExecutionState,
|
GraphExecutionState,
|
||||||
@ -92,7 +93,7 @@ async def get_session(
|
|||||||
async def add_node(
|
async def add_node(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
node: Annotated[
|
node: Annotated[
|
||||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||||
] = Body(description="The node to add"),
|
] = Body(description="The node to add"),
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Adds a node to the graph"""
|
"""Adds a node to the graph"""
|
||||||
@ -125,7 +126,7 @@ async def update_node(
|
|||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
node_path: str = Path(description="The path to the node in the graph"),
|
node_path: str = Path(description="The path to the node in the graph"),
|
||||||
node: Annotated[
|
node: Annotated[
|
||||||
Union[BaseInvocation.get_invocations()], Field(discriminator="type")
|
Union[BaseInvocation.get_invocations()], Field(discriminator="type") # type: ignore
|
||||||
] = Body(description="The new node"),
|
] = Body(description="The new node"),
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Updates a node in the graph and removes all linked edges"""
|
"""Updates a node in the graph and removes all linked edges"""
|
||||||
@ -186,7 +187,7 @@ async def delete_node(
|
|||||||
)
|
)
|
||||||
async def add_edge(
|
async def add_edge(
|
||||||
session_id: str = Path(description="The id of the session"),
|
session_id: str = Path(description="The id of the session"),
|
||||||
edge: tuple[EdgeConnection, EdgeConnection] = Body(description="The edge to add"),
|
edge: Edge = Body(description="The edge to add"),
|
||||||
) -> GraphExecutionState:
|
) -> GraphExecutionState:
|
||||||
"""Adds an edge to the graph"""
|
"""Adds an edge to the graph"""
|
||||||
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
session = ApiDependencies.invoker.services.graph_execution_manager.get(session_id)
|
||||||
@ -228,9 +229,9 @@ async def delete_edge(
|
|||||||
return Response(status_code=404)
|
return Response(status_code=404)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
edge = (
|
edge = Edge(
|
||||||
EdgeConnection(node_id=from_node_id, field=from_field),
|
source=EdgeConnection(node_id=from_node_id, field=from_field),
|
||||||
EdgeConnection(node_id=to_node_id, field=to_field),
|
destination=EdgeConnection(node_id=to_node_id, field=to_field)
|
||||||
)
|
)
|
||||||
session.delete_edge(edge)
|
session.delete_edge(edge)
|
||||||
ApiDependencies.invoker.services.graph_execution_manager.set(
|
ApiDependencies.invoker.services.graph_execution_manager.set(
|
||||||
|
@ -19,7 +19,7 @@ from .invocations.baseinvocation import BaseInvocation
|
|||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.model_manager_initializer import get_model_manager
|
from .services.model_manager_initializer import get_model_manager
|
||||||
from .services.restoration_services import RestorationServices
|
from .services.restoration_services import RestorationServices
|
||||||
from .services.graph import EdgeConnection, GraphExecutionState
|
from .services.graph import Edge, EdgeConnection, GraphExecutionState
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
from .services.invocation_services import InvocationServices
|
from .services.invocation_services import InvocationServices
|
||||||
@ -77,7 +77,7 @@ def get_command_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
def generate_matching_edges(
|
def generate_matching_edges(
|
||||||
a: BaseInvocation, b: BaseInvocation
|
a: BaseInvocation, b: BaseInvocation
|
||||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
) -> list[Edge]:
|
||||||
"""Generates all possible edges between two invocations"""
|
"""Generates all possible edges between two invocations"""
|
||||||
atype = type(a)
|
atype = type(a)
|
||||||
btype = type(b)
|
btype = type(b)
|
||||||
@ -94,9 +94,9 @@ def generate_matching_edges(
|
|||||||
matching_fields = matching_fields.difference(invalid_fields)
|
matching_fields = matching_fields.difference(invalid_fields)
|
||||||
|
|
||||||
edges = [
|
edges = [
|
||||||
(
|
Edge(
|
||||||
EdgeConnection(node_id=a.id, field=field),
|
source=EdgeConnection(node_id=a.id, field=field),
|
||||||
EdgeConnection(node_id=b.id, field=field),
|
destination=EdgeConnection(node_id=b.id, field=field)
|
||||||
)
|
)
|
||||||
for field in matching_fields
|
for field in matching_fields
|
||||||
]
|
]
|
||||||
@ -111,16 +111,15 @@ class SessionError(Exception):
|
|||||||
def invoke_all(context: CliContext):
|
def invoke_all(context: CliContext):
|
||||||
"""Runs all invocations in the specified session"""
|
"""Runs all invocations in the specified session"""
|
||||||
context.invoker.invoke(context.session, invoke_all=True)
|
context.invoker.invoke(context.session, invoke_all=True)
|
||||||
while not context.session.is_complete():
|
while not context.get_session().is_complete():
|
||||||
# Wait some time
|
# Wait some time
|
||||||
session = context.get_session()
|
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
# Print any errors
|
# Print any errors
|
||||||
if context.session.has_error():
|
if context.session.has_error():
|
||||||
for n in context.session.errors:
|
for n in context.session.errors:
|
||||||
print(
|
print(
|
||||||
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {session.errors[n]}"
|
f"Error in node {n} (source node {context.session.prepared_source_mapping[n]}): {context.session.errors[n]}"
|
||||||
)
|
)
|
||||||
|
|
||||||
raise SessionError()
|
raise SessionError()
|
||||||
@ -203,7 +202,7 @@ def invoke_cli():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Pipe previous command output (if there was a previous command)
|
# Pipe previous command output (if there was a previous command)
|
||||||
edges = []
|
edges: list[Edge] = list()
|
||||||
if len(history) > 0 or current_id != start_id:
|
if len(history) > 0 or current_id != start_id:
|
||||||
from_id = (
|
from_id = (
|
||||||
history[0] if current_id == start_id else str(current_id - 1)
|
history[0] if current_id == start_id else str(current_id - 1)
|
||||||
@ -225,19 +224,19 @@ def invoke_cli():
|
|||||||
matching_edges = generate_matching_edges(
|
matching_edges = generate_matching_edges(
|
||||||
link_node, command.command
|
link_node, command.command
|
||||||
)
|
)
|
||||||
matching_destinations = [e[1] for e in matching_edges]
|
matching_destinations = [e.destination for e in matching_edges]
|
||||||
edges = [e for e in edges if e[1] not in matching_destinations]
|
edges = [e for e in edges if e.destination not in matching_destinations]
|
||||||
edges.extend(matching_edges)
|
edges.extend(matching_edges)
|
||||||
|
|
||||||
if "link" in args and args["link"]:
|
if "link" in args and args["link"]:
|
||||||
for link in args["link"]:
|
for link in args["link"]:
|
||||||
edges = [e for e in edges if e[1].node_id != command.command.id and e[1].field != link[2]]
|
edges = [e for e in edges if e.destination.node_id != command.command.id and e.destination.field != link[2]]
|
||||||
edges.append(
|
edges.append(
|
||||||
(
|
Edge(
|
||||||
EdgeConnection(node_id=link[1], field=link[0]),
|
source=EdgeConnection(node_id=link[1], field=link[0]),
|
||||||
EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=command.command.id, field=link[2]
|
node_id=command.command.id, field=link[2]
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -44,6 +44,11 @@ class EdgeConnection(BaseModel):
|
|||||||
return hash(f"{self.node_id}.{self.field}")
|
return hash(f"{self.node_id}.{self.field}")
|
||||||
|
|
||||||
|
|
||||||
|
class Edge(BaseModel):
|
||||||
|
source: EdgeConnection = Field(description="The connection for the edge's from node and field")
|
||||||
|
destination: EdgeConnection = Field(description="The connection for the edge's to node and field")
|
||||||
|
|
||||||
|
|
||||||
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
def get_output_field(node: BaseInvocation, field: str) -> Any:
|
||||||
node_type = type(node)
|
node_type = type(node)
|
||||||
node_outputs = get_type_hints(node_type.get_output_type())
|
node_outputs = get_type_hints(node_type.get_output_type())
|
||||||
@ -194,7 +199,7 @@ class Graph(BaseModel):
|
|||||||
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
nodes: dict[str, Annotated[InvocationsUnion, Field(discriminator="type")]] = Field(
|
||||||
description="The nodes in this graph", default_factory=dict
|
description="The nodes in this graph", default_factory=dict
|
||||||
)
|
)
|
||||||
edges: list[tuple[EdgeConnection, EdgeConnection]] = Field(
|
edges: list[Edge] = Field(
|
||||||
description="The connections between nodes and their fields in this graph",
|
description="The connections between nodes and their fields in this graph",
|
||||||
default_factory=list,
|
default_factory=list,
|
||||||
)
|
)
|
||||||
@ -251,7 +256,7 @@ class Graph(BaseModel):
|
|||||||
except NodeNotFoundError:
|
except NodeNotFoundError:
|
||||||
pass # Ignore, not doesn't exist (should this throw?)
|
pass # Ignore, not doesn't exist (should this throw?)
|
||||||
|
|
||||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
def add_edge(self, edge: Edge) -> None:
|
||||||
"""Adds an edge to a graph
|
"""Adds an edge to a graph
|
||||||
|
|
||||||
:raises InvalidEdgeError: the provided edge is invalid.
|
:raises InvalidEdgeError: the provided edge is invalid.
|
||||||
@ -262,7 +267,7 @@ class Graph(BaseModel):
|
|||||||
else:
|
else:
|
||||||
raise InvalidEdgeError()
|
raise InvalidEdgeError()
|
||||||
|
|
||||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
def delete_edge(self, edge: Edge) -> None:
|
||||||
"""Deletes an edge from a graph"""
|
"""Deletes an edge from a graph"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -280,7 +285,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# Validate all edges reference nodes in the graph
|
# Validate all edges reference nodes in the graph
|
||||||
node_ids = set(
|
node_ids = set(
|
||||||
[e[0].node_id for e in self.edges] + [e[1].node_id for e in self.edges]
|
[e.source.node_id for e in self.edges] + [e.destination.node_id for e in self.edges]
|
||||||
)
|
)
|
||||||
if not all((self.has_node(node_id) for node_id in node_ids)):
|
if not all((self.has_node(node_id) for node_id in node_ids)):
|
||||||
return False
|
return False
|
||||||
@ -294,10 +299,10 @@ class Graph(BaseModel):
|
|||||||
if not all(
|
if not all(
|
||||||
(
|
(
|
||||||
are_connections_compatible(
|
are_connections_compatible(
|
||||||
self.get_node(e[0].node_id),
|
self.get_node(e.source.node_id),
|
||||||
e[0].field,
|
e.source.field,
|
||||||
self.get_node(e[1].node_id),
|
self.get_node(e.destination.node_id),
|
||||||
e[1].field,
|
e.destination.field,
|
||||||
)
|
)
|
||||||
for e in self.edges
|
for e in self.edges
|
||||||
)
|
)
|
||||||
@ -328,58 +333,58 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||||
"""Validates that a new edge doesn't create a cycle in the graph"""
|
"""Validates that a new edge doesn't create a cycle in the graph"""
|
||||||
|
|
||||||
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
# Validate that the nodes exist (edges may contain node paths, so we can't just check for nodes directly)
|
||||||
try:
|
try:
|
||||||
from_node = self.get_node(edge[0].node_id)
|
from_node = self.get_node(edge.source.node_id)
|
||||||
to_node = self.get_node(edge[1].node_id)
|
to_node = self.get_node(edge.destination.node_id)
|
||||||
except NodeNotFoundError:
|
except NodeNotFoundError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate that an edge to this node+field doesn't already exist
|
# Validate that an edge to this node+field doesn't already exist
|
||||||
input_edges = self._get_input_edges(edge[1].node_id, edge[1].field)
|
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
|
||||||
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate that no cycles would be created
|
# Validate that no cycles would be created
|
||||||
g = self.nx_graph_flat()
|
g = self.nx_graph_flat()
|
||||||
g.add_edge(edge[0].node_id, edge[1].node_id)
|
g.add_edge(edge.source.node_id, edge.destination.node_id)
|
||||||
if not nx.is_directed_acyclic_graph(g):
|
if not nx.is_directed_acyclic_graph(g):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate that the field types are compatible
|
# Validate that the field types are compatible
|
||||||
if not are_connections_compatible(
|
if not are_connections_compatible(
|
||||||
from_node, edge[0].field, to_node, edge[1].field
|
from_node, edge.source.field, to_node, edge.destination.field
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
# Validate if iterator output type matches iterator input type (if this edge results in both being set)
|
||||||
if isinstance(to_node, IterateInvocation) and edge[1].field == "collection":
|
if isinstance(to_node, IterateInvocation) and edge.destination.field == "collection":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge[1].node_id, new_input=edge[0]
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate if iterator input type matches output type (if this edge results in both being set)
|
# Validate if iterator input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(from_node, IterateInvocation) and edge[0].field == "item":
|
if isinstance(from_node, IterateInvocation) and edge.source.field == "item":
|
||||||
if not self._is_iterator_connection_valid(
|
if not self._is_iterator_connection_valid(
|
||||||
edge[0].node_id, new_output=edge[1]
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate if collector input type matches output type (if this edge results in both being set)
|
# Validate if collector input type matches output type (if this edge results in both being set)
|
||||||
if isinstance(to_node, CollectInvocation) and edge[1].field == "item":
|
if isinstance(to_node, CollectInvocation) and edge.destination.field == "item":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge[1].node_id, new_input=edge[0]
|
edge.destination.node_id, new_input=edge.source
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Validate if collector output type matches input type (if this edge results in both being set)
|
# Validate if collector output type matches input type (if this edge results in both being set)
|
||||||
if isinstance(from_node, CollectInvocation) and edge[0].field == "collection":
|
if isinstance(from_node, CollectInvocation) and edge.source.field == "collection":
|
||||||
if not self._is_collector_connection_valid(
|
if not self._is_collector_connection_valid(
|
||||||
edge[0].node_id, new_output=edge[1]
|
edge.source.node_id, new_output=edge.destination
|
||||||
):
|
):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -438,15 +443,15 @@ class Graph(BaseModel):
|
|||||||
# Remove the graph prefix from the node path
|
# Remove the graph prefix from the node path
|
||||||
new_graph_node_path = (
|
new_graph_node_path = (
|
||||||
new_node.id
|
new_node.id
|
||||||
if "." not in edge[1].node_id
|
if "." not in edge.destination.node_id
|
||||||
else f'{edge[1].node_id[edge[1].node_id.rindex("."):]}.{new_node.id}'
|
else f'{edge.destination.node_id[edge.destination.node_id.rindex("."):]}.{new_node.id}'
|
||||||
)
|
)
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
(
|
Edge(
|
||||||
edge[0],
|
source=edge.source,
|
||||||
EdgeConnection(
|
destination=EdgeConnection(
|
||||||
node_id=new_graph_node_path, field=edge[1].field
|
node_id=new_graph_node_path, field=edge.destination.field
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -454,51 +459,51 @@ class Graph(BaseModel):
|
|||||||
# Remove the graph prefix from the node path
|
# Remove the graph prefix from the node path
|
||||||
new_graph_node_path = (
|
new_graph_node_path = (
|
||||||
new_node.id
|
new_node.id
|
||||||
if "." not in edge[0].node_id
|
if "." not in edge.source.node_id
|
||||||
else f'{edge[0].node_id[edge[0].node_id.rindex("."):]}.{new_node.id}'
|
else f'{edge.source.node_id[edge.source.node_id.rindex("."):]}.{new_node.id}'
|
||||||
)
|
)
|
||||||
graph.add_edge(
|
graph.add_edge(
|
||||||
(
|
Edge(
|
||||||
EdgeConnection(
|
source=EdgeConnection(
|
||||||
node_id=new_graph_node_path, field=edge[0].field
|
node_id=new_graph_node_path, field=edge.source.field
|
||||||
),
|
),
|
||||||
edge[1],
|
destination=edge.destination
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_input_edges(
|
def _get_input_edges(
|
||||||
self, node_path: str, field: Optional[str] = None
|
self, node_path: str, field: Optional[str] = None
|
||||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
) -> list[Edge]:
|
||||||
"""Gets all input edges for a node"""
|
"""Gets all input edges for a node"""
|
||||||
edges = self._get_input_edges_and_graphs(node_path)
|
edges = self._get_input_edges_and_graphs(node_path)
|
||||||
|
|
||||||
# Filter to edges that match the field
|
# Filter to edges that match the field
|
||||||
filtered_edges = (e for e in edges if field is None or e[2][1].field == field)
|
filtered_edges = (e for e in edges if field is None or e[2].destination.field == field)
|
||||||
|
|
||||||
# Create full node paths for each edge
|
# Create full node paths for each edge
|
||||||
return [
|
return [
|
||||||
(
|
Edge(
|
||||||
EdgeConnection(
|
source=EdgeConnection(
|
||||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||||
field=e[0].field,
|
field=e.source.field,
|
||||||
),
|
|
||||||
EdgeConnection(
|
|
||||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
|
||||||
field=e[1].field,
|
|
||||||
),
|
),
|
||||||
|
destination=EdgeConnection(
|
||||||
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
|
field=e.destination.field,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
|
|
||||||
def _get_input_edges_and_graphs(
|
def _get_input_edges_and_graphs(
|
||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
) -> list[tuple["Graph", str, Edge]]:
|
||||||
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all input edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend(
|
edges.extend(
|
||||||
[(self, prefix, e) for e in self.edges if e[1].node_id == node_path]
|
[(self, prefix, e) for e in self.edges if e.destination.node_id == node_path]
|
||||||
)
|
)
|
||||||
|
|
||||||
node_id = (
|
node_id = (
|
||||||
@ -522,37 +527,37 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
def _get_output_edges(
|
def _get_output_edges(
|
||||||
self, node_path: str, field: str
|
self, node_path: str, field: str
|
||||||
) -> list[tuple[EdgeConnection, EdgeConnection]]:
|
) -> list[Edge]:
|
||||||
"""Gets all output edges for a node"""
|
"""Gets all output edges for a node"""
|
||||||
edges = self._get_output_edges_and_graphs(node_path)
|
edges = self._get_output_edges_and_graphs(node_path)
|
||||||
|
|
||||||
# Filter to edges that match the field
|
# Filter to edges that match the field
|
||||||
filtered_edges = (e for e in edges if e[2][0].field == field)
|
filtered_edges = (e for e in edges if e[2].source.field == field)
|
||||||
|
|
||||||
# Create full node paths for each edge
|
# Create full node paths for each edge
|
||||||
return [
|
return [
|
||||||
(
|
Edge(
|
||||||
EdgeConnection(
|
source=EdgeConnection(
|
||||||
node_id=self._get_node_path(e[0].node_id, prefix=prefix),
|
node_id=self._get_node_path(e.source.node_id, prefix=prefix),
|
||||||
field=e[0].field,
|
field=e.source.field,
|
||||||
),
|
|
||||||
EdgeConnection(
|
|
||||||
node_id=self._get_node_path(e[1].node_id, prefix=prefix),
|
|
||||||
field=e[1].field,
|
|
||||||
),
|
),
|
||||||
|
destination=EdgeConnection(
|
||||||
|
node_id=self._get_node_path(e.destination.node_id, prefix=prefix),
|
||||||
|
field=e.destination.field,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
for _, prefix, e in filtered_edges
|
for _, prefix, e in filtered_edges
|
||||||
]
|
]
|
||||||
|
|
||||||
def _get_output_edges_and_graphs(
|
def _get_output_edges_and_graphs(
|
||||||
self, node_path: str, prefix: Optional[str] = None
|
self, node_path: str, prefix: Optional[str] = None
|
||||||
) -> list[tuple["Graph", str, tuple[EdgeConnection, EdgeConnection]]]:
|
) -> list[tuple["Graph", str, Edge]]:
|
||||||
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
"""Gets all output edges for a node along with the graph they are in and the graph's path"""
|
||||||
edges = list()
|
edges = list()
|
||||||
|
|
||||||
# Return any input edges that appear in this graph
|
# Return any input edges that appear in this graph
|
||||||
edges.extend(
|
edges.extend(
|
||||||
[(self, prefix, e) for e in self.edges if e[0].node_id == node_path]
|
[(self, prefix, e) for e in self.edges if e.source.node_id == node_path]
|
||||||
)
|
)
|
||||||
|
|
||||||
node_id = (
|
node_id = (
|
||||||
@ -580,8 +585,8 @@ class Graph(BaseModel):
|
|||||||
new_input: Optional[EdgeConnection] = None,
|
new_input: Optional[EdgeConnection] = None,
|
||||||
new_output: Optional[EdgeConnection] = None,
|
new_output: Optional[EdgeConnection] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "collection")])
|
inputs = list([e.source for e in self._get_input_edges(node_path, "collection")])
|
||||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "item")])
|
outputs = list([e.destination for e in self._get_output_edges(node_path, "item")])
|
||||||
|
|
||||||
if new_input is not None:
|
if new_input is not None:
|
||||||
inputs.append(new_input)
|
inputs.append(new_input)
|
||||||
@ -622,8 +627,8 @@ class Graph(BaseModel):
|
|||||||
new_input: Optional[EdgeConnection] = None,
|
new_input: Optional[EdgeConnection] = None,
|
||||||
new_output: Optional[EdgeConnection] = None,
|
new_output: Optional[EdgeConnection] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
inputs = list([e[0] for e in self._get_input_edges(node_path, "item")])
|
inputs = list([e.source for e in self._get_input_edges(node_path, "item")])
|
||||||
outputs = list([e[1] for e in self._get_output_edges(node_path, "collection")])
|
outputs = list([e.destination for e in self._get_output_edges(node_path, "collection")])
|
||||||
|
|
||||||
if new_input is not None:
|
if new_input is not None:
|
||||||
inputs.append(new_input)
|
inputs.append(new_input)
|
||||||
@ -684,7 +689,7 @@ class Graph(BaseModel):
|
|||||||
# TODO: Cache this?
|
# TODO: Cache this?
|
||||||
g = nx.DiGraph()
|
g = nx.DiGraph()
|
||||||
g.add_nodes_from([n for n in self.nodes.keys()])
|
g.add_nodes_from([n for n in self.nodes.keys()])
|
||||||
g.add_edges_from(set([(e[0].node_id, e[1].node_id) for e in self.edges]))
|
g.add_edges_from(set([(e.source.node_id, e.destination.node_id) for e in self.edges]))
|
||||||
return g
|
return g
|
||||||
|
|
||||||
def nx_graph_flat(
|
def nx_graph_flat(
|
||||||
@ -711,7 +716,7 @@ class Graph(BaseModel):
|
|||||||
|
|
||||||
# TODO: figure out if iteration nodes need to be expanded
|
# TODO: figure out if iteration nodes need to be expanded
|
||||||
|
|
||||||
unique_edges = set([(e[0].node_id, e[1].node_id) for e in self.edges])
|
unique_edges = set([(e.source.node_id, e.destination.node_id) for e in self.edges])
|
||||||
g.add_edges_from(
|
g.add_edges_from(
|
||||||
[
|
[
|
||||||
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
(self._get_node_path(e[0], prefix), self._get_node_path(e[1], prefix))
|
||||||
@ -841,13 +846,13 @@ class GraphExecutionState(BaseModel):
|
|||||||
input_collection_prepared_node_id = next(
|
input_collection_prepared_node_id = next(
|
||||||
n[1]
|
n[1]
|
||||||
for n in iteration_node_map
|
for n in iteration_node_map
|
||||||
if n[0] == input_collection_edge[0].node_id
|
if n[0] == input_collection_edge.source.node_id
|
||||||
)
|
)
|
||||||
input_collection_prepared_node_output = self.results[
|
input_collection_prepared_node_output = self.results[
|
||||||
input_collection_prepared_node_id
|
input_collection_prepared_node_id
|
||||||
]
|
]
|
||||||
input_collection = getattr(
|
input_collection = getattr(
|
||||||
input_collection_prepared_node_output, input_collection_edge[0].field
|
input_collection_prepared_node_output, input_collection_edge.source.field
|
||||||
)
|
)
|
||||||
self_iteration_count = len(input_collection)
|
self_iteration_count = len(input_collection)
|
||||||
|
|
||||||
@ -864,11 +869,11 @@ class GraphExecutionState(BaseModel):
|
|||||||
new_edges = list()
|
new_edges = list()
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
for input_node_id in (
|
for input_node_id in (
|
||||||
n[1] for n in iteration_node_map if n[0] == edge[0].node_id
|
n[1] for n in iteration_node_map if n[0] == edge.source.node_id
|
||||||
):
|
):
|
||||||
new_edge = (
|
new_edge = Edge(
|
||||||
EdgeConnection(node_id=input_node_id, field=edge[0].field),
|
source=EdgeConnection(node_id=input_node_id, field=edge.source.field),
|
||||||
EdgeConnection(node_id="", field=edge[1].field),
|
destination=EdgeConnection(node_id="", field=edge.destination.field),
|
||||||
)
|
)
|
||||||
new_edges.append(new_edge)
|
new_edges.append(new_edge)
|
||||||
|
|
||||||
@ -893,9 +898,9 @@ class GraphExecutionState(BaseModel):
|
|||||||
|
|
||||||
# Add new edges to execution graph
|
# Add new edges to execution graph
|
||||||
for edge in new_edges:
|
for edge in new_edges:
|
||||||
new_edge = (
|
new_edge = Edge(
|
||||||
edge[0],
|
source=edge.source,
|
||||||
EdgeConnection(node_id=new_node.id, field=edge[1].field),
|
destination=EdgeConnection(node_id=new_node.id, field=edge.destination.field),
|
||||||
)
|
)
|
||||||
self.execution_graph.add_edge(new_edge)
|
self.execution_graph.add_edge(new_edge)
|
||||||
|
|
||||||
@ -1043,26 +1048,26 @@ class GraphExecutionState(BaseModel):
|
|||||||
return self.execution_graph.nodes[next_node]
|
return self.execution_graph.nodes[next_node]
|
||||||
|
|
||||||
def _prepare_inputs(self, node: BaseInvocation):
|
def _prepare_inputs(self, node: BaseInvocation):
|
||||||
input_edges = [e for e in self.execution_graph.edges if e[1].node_id == node.id]
|
input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
|
||||||
if isinstance(node, CollectInvocation):
|
if isinstance(node, CollectInvocation):
|
||||||
output_collection = [
|
output_collection = [
|
||||||
getattr(self.results[edge[0].node_id], edge[0].field)
|
getattr(self.results[edge.source.node_id], edge.source.field)
|
||||||
for edge in input_edges
|
for edge in input_edges
|
||||||
if edge[1].field == "item"
|
if edge.destination.field == "item"
|
||||||
]
|
]
|
||||||
setattr(node, "collection", output_collection)
|
setattr(node, "collection", output_collection)
|
||||||
else:
|
else:
|
||||||
for edge in input_edges:
|
for edge in input_edges:
|
||||||
output_value = getattr(self.results[edge[0].node_id], edge[0].field)
|
output_value = getattr(self.results[edge.source.node_id], edge.source.field)
|
||||||
setattr(node, edge[1].field, output_value)
|
setattr(node, edge.destination.field, output_value)
|
||||||
|
|
||||||
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
# TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
|
||||||
def _is_edge_valid(self, edge: tuple[EdgeConnection, EdgeConnection]) -> bool:
|
def _is_edge_valid(self, edge: Edge) -> bool:
|
||||||
if not self._is_edge_valid(edge):
|
if not self._is_edge_valid(edge):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Invalid if destination has already been prepared or executed
|
# Invalid if destination has already been prepared or executed
|
||||||
if edge[1].node_id in self.source_prepared_mapping:
|
if edge.destination.node_id in self.source_prepared_mapping:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Otherwise, the edge is valid
|
# Otherwise, the edge is valid
|
||||||
@ -1089,17 +1094,17 @@ class GraphExecutionState(BaseModel):
|
|||||||
)
|
)
|
||||||
self.graph.delete_node(node_path)
|
self.graph.delete_node(node_path)
|
||||||
|
|
||||||
def add_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
def add_edge(self, edge: Edge) -> None:
|
||||||
if not self._is_node_updatable(edge[1].node_id):
|
if not self._is_node_updatable(edge.destination.node_id):
|
||||||
raise NodeAlreadyExecutedError(
|
raise NodeAlreadyExecutedError(
|
||||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot be linked to"
|
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot be linked to"
|
||||||
)
|
)
|
||||||
self.graph.add_edge(edge)
|
self.graph.add_edge(edge)
|
||||||
|
|
||||||
def delete_edge(self, edge: tuple[EdgeConnection, EdgeConnection]) -> None:
|
def delete_edge(self, edge: Edge) -> None:
|
||||||
if not self._is_node_updatable(edge[1].node_id):
|
if not self._is_node_updatable(edge.destination.node_id):
|
||||||
raise NodeAlreadyExecutedError(
|
raise NodeAlreadyExecutedError(
|
||||||
f"Destination node {edge[1].node_id} has already been prepared or executed and cannot have a source edge deleted"
|
f"Destination node {edge.destination.node_id} has already been prepared or executed and cannot have a source edge deleted"
|
||||||
)
|
)
|
||||||
self.graph.delete_edge(edge)
|
self.graph.delete_edge(edge)
|
||||||
|
|
||||||
|
@ -105,17 +105,20 @@
|
|||||||
|
|
||||||
// Start building nodes
|
// Start building nodes
|
||||||
var id = 1;
|
var id = 1;
|
||||||
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "sampler": sampler, "steps": steps, "seed": seed};
|
var initialNode = {"id": id.toString(), "type": "txt2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": seed};
|
||||||
|
id++;
|
||||||
|
var i2iNode = {"id": id.toString(), "type": "img2img", "prompt": prompt, "model": "stable-diffusion-1-5", "sampler": sampler, "steps": steps, "seed": Math.floor(Math.random() * 10000)};
|
||||||
id++;
|
id++;
|
||||||
var upscaleNode = {"id": id.toString(), "type": "show_image" };
|
var upscaleNode = {"id": id.toString(), "type": "show_image" };
|
||||||
id++
|
id++
|
||||||
|
|
||||||
nodes = {};
|
nodes = {};
|
||||||
nodes[initialNode.id] = initialNode;
|
nodes[initialNode.id] = initialNode;
|
||||||
|
nodes[i2iNode.id] = i2iNode;
|
||||||
nodes[upscaleNode.id] = upscaleNode;
|
nodes[upscaleNode.id] = upscaleNode;
|
||||||
links = [
|
links = [
|
||||||
[{ "node_id": initialNode.id, field: "image" },
|
{ "source": { "node_id": initialNode.id, field: "image" }, "destination": { "node_id": i2iNode.id, field: "image" }},
|
||||||
{ "node_id": upscaleNode.id, field: "image" }]
|
{ "source": { "node_id": i2iNode.id, field: "image" }, "destination": { "node_id": upscaleNode.id, field: "image" }}
|
||||||
];
|
];
|
||||||
// expandSize = 128;
|
// expandSize = 128;
|
||||||
// for (var i = 0; i < 6; ++i) {
|
// for (var i = 0; i < 6; ++i) {
|
||||||
|
@ -1,15 +1,18 @@
|
|||||||
from invokeai.app.invocations.image import *
|
from invokeai.app.invocations.image import *
|
||||||
|
|
||||||
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
from .test_nodes import ListPassThroughInvocation, PromptTestInvocation
|
||||||
from invokeai.app.services.graph import Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
from invokeai.app.services.graph import Edge, Graph, GraphInvocation, InvalidEdgeError, NodeAlreadyInGraphError, NodeNotFoundError, are_connections_compatible, EdgeConnection, CollectInvocation, IterateInvocation
|
||||||
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
from invokeai.app.invocations.generate import ImageToImageInvocation, TextToImageInvocation
|
||||||
from invokeai.app.invocations.upscale import UpscaleInvocation
|
from invokeai.app.invocations.upscale import UpscaleInvocation
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
# Helpers
|
# Helpers
|
||||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
return Edge(
|
||||||
|
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||||
|
destination=EdgeConnection(node_id = to_id, field = to_field)
|
||||||
|
)
|
||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
def test_connections_are_compatible():
|
def test_connections_are_compatible():
|
||||||
@ -108,7 +111,7 @@ def test_graph_allows_non_conflicting_id_change():
|
|||||||
assert g.get_node("3").prompt == "Banana sushi"
|
assert g.get_node("3").prompt == "Banana sushi"
|
||||||
|
|
||||||
assert len(g.edges) == 1
|
assert len(g.edges) == 1
|
||||||
assert (EdgeConnection(node_id = "3", field = "image"), EdgeConnection(node_id = "2", field = "image")) in g.edges
|
assert Edge(source=EdgeConnection(node_id = "3", field = "image"), destination=EdgeConnection(node_id = "2", field = "image")) in g.edges
|
||||||
|
|
||||||
def test_graph_fails_to_update_node_id_if_conflict():
|
def test_graph_fails_to_update_node_id_if_conflict():
|
||||||
g = Graph()
|
g = Graph()
|
||||||
@ -490,10 +493,10 @@ def test_graph_can_deserialize():
|
|||||||
assert g2.nodes['1'] is not None
|
assert g2.nodes['1'] is not None
|
||||||
assert g2.nodes['2'] is not None
|
assert g2.nodes['2'] is not None
|
||||||
assert len(g2.edges) == 1
|
assert len(g2.edges) == 1
|
||||||
assert g2.edges[0][0].node_id == '1'
|
assert g2.edges[0].source.node_id == '1'
|
||||||
assert g2.edges[0][0].field == 'image'
|
assert g2.edges[0].source.field == 'image'
|
||||||
assert g2.edges[0][1].node_id == '2'
|
assert g2.edges[0].destination.node_id == '2'
|
||||||
assert g2.edges[0][1].field == 'image'
|
assert g2.edges[0].destination.field == 'image'
|
||||||
|
|
||||||
def test_graph_can_generate_schema():
|
def test_graph_can_generate_schema():
|
||||||
# Not throwing on this line is sufficient
|
# Not throwing on this line is sufficient
|
||||||
|
@ -64,10 +64,12 @@ class PromptCollectionTestInvocation(BaseInvocation):
|
|||||||
|
|
||||||
|
|
||||||
from invokeai.app.services.events import EventServiceBase
|
from invokeai.app.services.events import EventServiceBase
|
||||||
from invokeai.app.services.graph import EdgeConnection
|
from invokeai.app.services.graph import Edge, EdgeConnection
|
||||||
|
|
||||||
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> tuple[EdgeConnection, EdgeConnection]:
|
def create_edge(from_id: str, from_field: str, to_id: str, to_field: str) -> Edge:
|
||||||
return (EdgeConnection(node_id = from_id, field = from_field), EdgeConnection(node_id = to_id, field = to_field))
|
return Edge(
|
||||||
|
source=EdgeConnection(node_id = from_id, field = from_field),
|
||||||
|
destination=EdgeConnection(node_id = to_id, field = to_field))
|
||||||
|
|
||||||
|
|
||||||
class TestEvent:
|
class TestEvent:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user