From 4427960acb2383d97f7e562aef4873e326512c14 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 25 Jun 2024 19:21:15 +1000 Subject: [PATCH] feat(ui): add updateNode to Graph --- .../nodes/util/graph/generation/Graph.test.ts | 68 +++++++++++++++++++ .../nodes/util/graph/generation/Graph.ts | 24 +++++++ 2 files changed, 92 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts index 67fabbc158..c3c3ca2348 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.test.ts @@ -1,3 +1,4 @@ +import { deepClone } from 'common/util/deepClone'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import type { AnyInvocation, Invocation } from 'services/api/types'; import { assert, AssertionError, is } from 'tsafe'; @@ -70,6 +71,73 @@ describe('Graph', () => { }); }); + describe('updateNode', () => { + const initialNode: Invocation<'add'> = { + id: 'old-id', + type: 'add', + a: 1, + }; + + it('should update node properties correctly', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const updates = { is_intermediate: true, use_cache: true }; + const updatedNode = g.updateNode(n, updates); + expect(updatedNode.is_intermediate).toBe(true); + expect(updatedNode.use_cache).toBe(true); + }); + + it('should allow updating the node id and update related edges', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const n2 = g.addNode({ id: 'node-2', type: 'add' }); + const n3 = g.addNode({ id: 'node-4', type: 'add' }); + const oldId = n.id; + const newId = 'new-id'; + const e1 = g.addEdge(n, 'value', n2, 'a'); + const e2 = g.addEdge(n3, 'value', n, 'a'); + g.updateNode(n, { id: newId }); + expect(g.hasNode(newId)).toBe(true); + expect(g.hasNode(oldId)).toBe(false); + expect(e1.source.node_id).toBe(newId); + expect(e2.destination.node_id).toBe(newId); + }); + + it('should throw an error if updated id already exists', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const n2 = g.addNode({ + id: 'other-id', + type: 'add', + }); + expect(() => g.updateNode(n, { id: n2.id })).toThrowError(AssertionError); + }); + + it('should preserve other fields not specified in updates', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const updatedNode = g.updateNode(n, { b: 3 }); + expect(updatedNode.b).toBe(3); + expect(updatedNode.a).toBe(initialNode.a); + }); + + it('should allow changing multiple properties at once', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const updatedNode = g.updateNode(n, { a: 2, b: 3 }); + expect(updatedNode.a).toBe(2); + expect(updatedNode.b).toBe(3); + }); + + it('should handle updates with no changes gracefully', () => { + const g = new Graph(); + const n = g.addNode(deepClone(initialNode)); + const updates = {}; + const updatedNode = g.updateNode(n, updates); + expect(updatedNode).toEqual(n); + }); + }); + describe('addEdge', () => { const add: Invocation<'add'> = { id: 'from-node', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts index 213adac4b8..db96a5f7d0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/Graph.ts @@ -100,6 +100,30 @@ export class Graph { } } + updateNode(node: Invocation, changes: Partial>): Invocation { + if (changes.id) { + assert(!this.hasNode(changes.id), `Node with id ${changes.id} already exists`); + const oldId = node.id; + const newId = changes.id; + this._graph.nodes[newId] = node; + delete this._graph.nodes[node.id]; + node.id = newId; + + this._graph.edges.forEach((edge) => { + if (edge.source.node_id === oldId) { + edge.source.node_id = newId; + } + if (edge.destination.node_id === oldId) { + edge.destination.node_id = newId; + } + }); + } + + Object.assign(node, changes); + + return node; + } + /** * Get the immediate incomers of a node. * @param node The node to get the incomers of.