diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 854defc945..80f56b49d3 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,7 +2,7 @@ import copy import itertools -from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints +from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import BaseModel, ConfigDict, field_validator, model_validator @@ -141,6 +141,16 @@ def are_connections_compatible( return are_connection_types_compatible(from_node_field, to_node_field) +T = TypeVar("T") + + +def copydeep(obj: T) -> T: + """Deep-copies an object. If it is a pydantic model, use the model's copy method.""" + if isinstance(obj, BaseModel): + return obj.model_copy(deep=True) + return copy.deepcopy(obj) + + class NodeAlreadyInGraphError(ValueError): pass @@ -1118,17 +1128,22 @@ class GraphExecutionState(BaseModel): def _prepare_inputs(self, node: BaseInvocation): input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id] + # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input + # will see the mutation. if isinstance(node, CollectInvocation): output_collection = [ - getattr(self.results[edge.source.node_id], edge.source.field) + copydeep(getattr(self.results[edge.source.node_id], edge.source.field)) for edge in input_edges if edge.destination.field == "item" ] node.collection = output_collection else: for edge in input_edges: - output_value = getattr(self.results[edge.source.node_id], edge.source.field) - setattr(node, edge.destination.field, output_value) + setattr( + node, + edge.destination.field, + copydeep(getattr(self.results[edge.source.node_id], edge.source.field)), + ) # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state def _is_edge_valid(self, edge: Edge) -> bool: