From d20f98fb4fd3c2bbcdb9a6f6d2837b296c0d2da8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 9 Feb 2024 20:28:19 +1100 Subject: [PATCH] fix(nodes): deep copy graph inputs The change to memory session storage brings a subtle behaviour change. Previously, we serialized and deserialized everything (e.g. field state, invocation outputs, etc) constantly. The meant we were effectively working with deep-copied objects at all time. We could mutate objects freely without worrying about other references to the object. With memory storage, objects are now passed around by reference, and we cannot handle them in the same way. This is problematic for nodes that mutate their own inputs. There are two ways this causes a problem: - An output is used as input for multiple nodes. If the first node mutates the output object while `invoke`ing, the next node will get the mutated object. - The invocation cache stores live python objects. When a node mutates an output pulled from the cache, the next node that uses the cached object will get the mutated object. The solution is to deep-copy a node's inputs as they are set, effectively reproducing the same behaviour as we had with the SQLite session storage. Nodes can safely mutate their inputs and those changes never leave the node's scope. Closes #5665 --- invokeai/app/services/shared/graph.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 854defc945..80f56b49d3 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -2,7 +2,7 @@ import copy import itertools -from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints +from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import BaseModel, ConfigDict, field_validator, model_validator @@ -141,6 +141,16 @@ def are_connections_compatible( return are_connection_types_compatible(from_node_field, to_node_field) +T = TypeVar("T") + + +def copydeep(obj: T) -> T: + """Deep-copies an object. If it is a pydantic model, use the model's copy method.""" + if isinstance(obj, BaseModel): + return obj.model_copy(deep=True) + return copy.deepcopy(obj) + + class NodeAlreadyInGraphError(ValueError): pass @@ -1118,17 +1128,22 @@ class GraphExecutionState(BaseModel): def _prepare_inputs(self, node: BaseInvocation): input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id] + # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input + # will see the mutation. if isinstance(node, CollectInvocation): output_collection = [ - getattr(self.results[edge.source.node_id], edge.source.field) + copydeep(getattr(self.results[edge.source.node_id], edge.source.field)) for edge in input_edges if edge.destination.field == "item" ] node.collection = output_collection else: for edge in input_edges: - output_value = getattr(self.results[edge.source.node_id], edge.source.field) - setattr(node, edge.destination.field, output_value) + setattr( + node, + edge.destination.field, + copydeep(getattr(self.results[edge.source.node_id], edge.source.field)), + ) # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state def _is_edge_valid(self, edge: Edge) -> bool: