fix(nodes): deep copy graph inputs (#5686)

## What type of PR is this? (check all applicable)

- [ ] Refactor
- [ ] Feature
- [x] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Description

The change to memory session storage brings a subtle behaviour change.

Previously, we serialized and deserialized everything (e.g. field state,
invocation outputs, etc.) constantly. This meant we were effectively
working with deep-copied objects at all times, so we could mutate objects
freely without worrying about other references to them.

With memory storage, objects are now passed around by reference, and we
cannot handle them in the same way.
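
As a minimal sketch of the difference (plain dicts and `json` here, not the actual storage code):

```python
import json

# SQLite-style storage: serialize on write, deserialize on read.
record = {"skipped_layers": 1}
restored = json.loads(json.dumps(record))  # the round trip yields a deep copy
restored["skipped_layers"] += 1
print(record["skipped_layers"])  # 1 -- the stored object is untouched

# Memory-style storage: the same live object is handed back by reference.
shared = record
shared["skipped_layers"] += 1
print(record["skipped_layers"])  # 2 -- the mutation is visible to every holder
```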

This is problematic for nodes that mutate their own inputs. There are
two ways this causes a problem (both sketched after this list):

- An output is used as the input for multiple nodes. If the first node
mutates the output object while `invoke`ing, the next node will get the
mutated object.
- The invocation cache stores live Python objects. When a node mutates
an output pulled from the cache, the next node that uses the cached
object will get the mutated object.
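
Both cases boil down to the same aliasing problem. A contrived, self-contained sketch (plain dicts standing in for real invocation outputs):

```python
# `output` stands in for a single node output (or a cached output) that two
# downstream nodes receive as their input.
output = {"skipped_layers": 0}

node_a_input = output                  # passed by reference, no copy made
node_a_input["skipped_layers"] += 1    # node A mutates its input during invoke

node_b_input = output                  # node B gets the already-mutated object
print(node_b_input["skipped_layers"])  # 1, but node B expected 0
```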

The solution is to deep-copy a node's inputs as they are set,
effectively reproducing the behaviour we had with the SQLite session
storage (see the diff below). Nodes can safely mutate their inputs, and
those changes never leave the node's scope.

## Related Tickets & Documents


- Closes #5665

The root issue affects CLIP Skip because that node mutates its input
`ClipField`. Specifically, it increments `self.clip.skipped_layers` and
passes `self.clip` as its output. I don't know if there are any other
nodes that do this.
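
For context, the problematic pattern looks roughly like this (a simplified, self-contained paraphrase; the real node and `ClipField` have more fields):

```python
from dataclasses import dataclass


@dataclass
class ClipField:  # simplified stand-in for the real ClipField
    skipped_layers: int = 0


@dataclass
class ClipSkipNode:  # paraphrase of the CLIP Skip invoke() pattern
    clip: ClipField
    skipped_layers: int = 1

    def invoke(self) -> ClipField:
        # Mutates the input in place and returns that same object as output.
        self.clip.skipped_layers += self.skipped_layers
        return self.clip


shared = ClipField()
first = ClipSkipNode(clip=shared)
second = ClipSkipNode(clip=shared)
first.invoke()
print(second.clip.skipped_layers)  # already 1 before second.invoke() runs
```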

## QA Instructions, Screenshots, Recordings

There are two issues to reproduce.

First is the caching issue:


![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/7a251e48-bc70-4b8e-8816-84aac41ce4d3)

Note the cache is enabled. Run this simple graph a couple times, and
check the outputs of the CLIP Skip node. You'll see the `skipped_layers`
value increasing each time.

Second is the nodes-sharing-inputs issue:


![image](https://github.com/invoke-ai/InvokeAI/assets/4822129/ecdaefab-2beb-4950-b4bf-2a5738ce6832)

Note the cache is _disabled_. Run the graph a couple times and check the
outputs of the two CLIP Skip nodes. You'll see that one has the expected
value for `skipped_layers` and the other has double that.

Now update to the PR and try again. You should see `skipped_layers` is
the right value in all cases.


## Merge Plan

This PR can be merged when approved. It needs a real review with
braintime.

The change (commit `cd169ee082`, Brandon, 2024-02-09):

```diff
@@ -2,7 +2,7 @@
 import copy
 import itertools
-from typing import Annotated, Any, Optional, Union, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Optional, TypeVar, Union, get_args, get_origin, get_type_hints
 
 import networkx as nx
 from pydantic import BaseModel, ConfigDict, field_validator, model_validator
@@ -141,6 +141,16 @@ def are_connections_compatible(
     return are_connection_types_compatible(from_node_field, to_node_field)
 
 
+T = TypeVar("T")
+
+
+def copydeep(obj: T) -> T:
+    """Deep-copies an object. If it is a pydantic model, use the model's copy method."""
+    if isinstance(obj, BaseModel):
+        return obj.model_copy(deep=True)
+    return copy.deepcopy(obj)
+
+
 class NodeAlreadyInGraphError(ValueError):
     pass
@@ -1118,17 +1128,22 @@ class GraphExecutionState(BaseModel):
     def _prepare_inputs(self, node: BaseInvocation):
         input_edges = [e for e in self.execution_graph.edges if e.destination.node_id == node.id]
 
+        # Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
+        # will see the mutation.
         if isinstance(node, CollectInvocation):
             output_collection = [
-                getattr(self.results[edge.source.node_id], edge.source.field)
+                copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
                 for edge in input_edges
                 if edge.destination.field == "item"
             ]
 
             node.collection = output_collection
         else:
             for edge in input_edges:
-                output_value = getattr(self.results[edge.source.node_id], edge.source.field)
-                setattr(node, edge.destination.field, output_value)
+                setattr(
+                    node,
+                    edge.destination.field,
+                    copydeep(getattr(self.results[edge.source.node_id], edge.source.field)),
+                )
 
     # TODO: Add API for modifying underlying graph that checks if the change will be valid given the current execution state
     def _is_edge_valid(self, edge: Edge) -> bool:
```
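
If it helps review, here is a quick self-contained check of the `copydeep` semantics (the `ClipLike` model is a made-up stand-in, not the real `ClipField`):

```python
import copy
from typing import TypeVar

from pydantic import BaseModel

T = TypeVar("T")


def copydeep(obj: T) -> T:
    # Same helper as in the diff above: pydantic models use model_copy.
    if isinstance(obj, BaseModel):
        return obj.model_copy(deep=True)
    return copy.deepcopy(obj)


class ClipLike(BaseModel):  # made-up stand-in for ClipField
    skipped_layers: int = 0


original = ClipLike()
duplicate = copydeep(original)
duplicate.skipped_layers += 1
print(original.skipped_layers, duplicate.skipped_layers)  # 0 1
```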