From 8f6078d007ffa52607ad79f11a5d59e45ef60d58 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 5 May 2024 16:14:54 +1000 Subject: [PATCH] feat(ui): refine graph building util Simpler types and API surface. --- .../features/nodes/util/graph/Graph.test.ts | 198 +++++++----------- .../src/features/nodes/util/graph/Graph.ts | 168 +++++---------- .../nodes/util/graph/MetadataUtil.test.ts | 13 +- .../features/nodes/util/graph/MetadataUtil.ts | 53 +++-- .../frontend/web/src/services/api/types.ts | 35 ++-- 5 files changed, 189 insertions(+), 278 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts index 6a38dbd218..71bcd9331c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts @@ -69,63 +69,20 @@ describe('Graph', () => { }); }); - describe('updateNode', () => { - it('should update the node with the provided id', () => { - const g = new Graph(); - const node: Invocation<'add'> = { - id: 'test-node', - type: 'add', - a: 1, - }; - g.addNode(node); - const updatedNode = g.updateNode('test-node', 'add', { - a: 2, - }); - expect(g.getNode('test-node', 'add').a).toBe(2); - expect(node).toBe(updatedNode); - }); - it('should throw an error if the node is not found', () => { - expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError); - }); - it('should throw an error if the node is found but has the wrong type', () => { - const g = new Graph(); - g.addNode({ - id: 'test-node', - type: 'add', - a: 1, - }); - expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError); - }); - it('should infer types correctly when `type` is omitted', () => { - const g = new Graph(); - g.addNode({ - id: 'test-node', - type: 'add', - a: 1, - }); - const updatedNode = g.updateNode('test-node', 'add', { - a: 2, - }); - assert(is>(updatedNode)); - }); - it('should infer types correctly when `type` is provided', () => { - const g = new Graph(); - g.addNode({ - id: 'test-node', - type: 'add', - a: 1, - }); - const updatedNode = g.updateNode('test-node', 'add', { - a: 2, - }); - assert(is>(updatedNode)); - }); - }); - describe('addEdge', () => { + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; it('should add an edge to the graph with the provided values', () => { const g = new Graph(); - g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + g.addNode(add); + g.addNode(sub); + g.addEdge(add, 'value', sub, 'b'); expect(g._graph.edges.length).toBe(1); expect(g._graph.edges[0]).toEqual({ source: { node_id: 'from-node', field: 'value' }, @@ -134,19 +91,19 @@ describe('Graph', () => { }); it('should throw an error if the edge already exists', () => { const g = new Graph(); - g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); - expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError); + g.addEdge(add, 'value', sub, 'b'); + expect(() => g.addEdge(add, 'value', sub, 'b')).toThrowError(AssertionError); }); it('should infer field names', () => { const g = new Graph(); // @ts-expect-error The first field must be a valid output field of the first type arg - g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a'); + g.addEdge(add, 'not-a-valid-field', add, 'a'); // @ts-expect-error The second field must be a valid input field of the second type arg - g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field'); + g.addEdge(add, 'value', sub, 'not-a-valid-field'); // @ts-expect-error The first field must be any valid output field - g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a'); + g.addEdge(add, 'not-a-valid-field', sub, 'a'); // @ts-expect-error The second field must be any valid input field - g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field'); + g.addEdge(add, 'clip', sub, 'not-a-valid-field'); }); }); @@ -161,38 +118,9 @@ describe('Graph', () => { const n = g.getNode('test-node'); expect(n).toBe(node); }); - it('should return the node with the provided id and type', () => { - const n = g.getNode('test-node', 'add'); - expect(n).toBe(node); - assert(is>(node)); - }); it('should throw an error if the node is not found', () => { expect(() => g.getNode('not-found')).toThrowError(AssertionError); }); - it('should throw an error if the node is found but has the wrong type', () => { - expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError); - }); - }); - - describe('getNodeSafe', () => { - const g = new Graph(); - const node = g.addNode({ - id: 'test-node', - type: 'add', - }); - it('should return the node if it is found', () => { - expect(g.getNodeSafe('test-node')).toBe(node); - }); - it('should return the node if it is found with the provided type', () => { - expect(g.getNodeSafe('test-node')).toBe(node); - assert(is>(node)); - }); - it("should return undefined if the node isn't found", () => { - expect(g.getNodeSafe('not-found')).toBeUndefined(); - }); - it('should return undefined if the node is found but has the wrong type', () => { - expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined(); - }); }); describe('hasNode', () => { @@ -212,40 +140,42 @@ describe('Graph', () => { describe('getEdge', () => { const g = new Graph(); - g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; + g.addEdge(add, 'value', sub, 'b'); it('should return the edge with the provided values', () => { - expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({ + expect(g.getEdge(add, 'value', sub, 'b')).toEqual({ source: { node_id: 'from-node', field: 'value' }, destination: { node_id: 'to-node', field: 'b' }, }); }); it('should throw an error if the edge is not found', () => { - expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError); - }); - }); - - describe('getEdgeSafe', () => { - const g = new Graph(); - g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); - it('should return the edge if it is found', () => { - expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({ - source: { node_id: 'from-node', field: 'value' }, - destination: { node_id: 'to-node', field: 'b' }, - }); - }); - it('should return undefined if the edge is not found', () => { - expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined(); + expect(() => g.getEdge(add, 'value', sub, 'a')).toThrowError(AssertionError); }); }); describe('hasEdge', () => { const g = new Graph(); - g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; + g.addEdge(add, 'value', sub, 'b'); it('should return true if the edge is in the graph', () => { - expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true); + expect(g.hasEdge(add, 'value', sub, 'b')).toBe(true); }); it('should return false if the edge is not in the graph', () => { - expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false); + expect(g.hasEdge(add, 'value', sub, 'a')).toBe(false); }); }); @@ -256,7 +186,15 @@ describe('Graph', () => { }); it('should raise an error if the graph is invalid', () => { const g = new Graph(); - g.addEdge('from-node', 'value', 'to-node', 'b'); + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; + g.addEdge(add, 'value', sub, 'b'); expect(() => g.getGraph()).toThrowError(AssertionError); }); }); @@ -264,7 +202,15 @@ describe('Graph', () => { describe('getGraphSafe', () => { it('should return the graph even if it is invalid', () => { const g = new Graph(); - g.addEdge('from-node', 'value', 'to-node', 'b'); + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; + g.addEdge(add, 'value', sub, 'b'); expect(g.getGraphSafe()).toBe(g._graph); }); }); @@ -276,8 +222,16 @@ describe('Graph', () => { }); it('should throw an error if the graph is invalid', () => { const g = new Graph(); + const add: Invocation<'add'> = { + id: 'from-node', + type: 'add', + }; + const sub: Invocation<'sub'> = { + id: 'to-node', + type: 'sub', + }; // edge from nowhere to nowhere - g.addEdge('from-node', 'value', 'to-node', 'b'); + g.addEdge(add, 'value', sub, 'b'); expect(() => g.validate()).toThrowError(AssertionError); }); }); @@ -304,34 +258,34 @@ describe('Graph', () => { id: 'n5', type: 'add', }); - const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a'); - const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b'); - const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a'); - const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b'); + const e1 = g.addEdge(n1, 'value', n3, 'a'); + const e2 = g.addEdge(n2, 'height', n3, 'b'); + const e3 = g.addEdge(n3, 'value', n4, 'a'); + const e4 = g.addEdge(n3, 'value', n5, 'b'); describe('getEdgesFrom', () => { it('should return the edges that start at the provided node', () => { - expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]); + expect(g.getEdgesFrom(n3)).toEqual([e3, e4]); }); it('should return the edges that start at the provided node and have the provided field', () => { - expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]); + expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]); }); }); describe('getEdgesTo', () => { it('should return the edges that end at the provided node', () => { - expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]); + expect(g.getEdgesTo(n3)).toEqual([e1, e2]); }); it('should return the edges that end at the provided node and have the provided field', () => { - expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]); + expect(g.getEdgesTo(n3, 'b')).toEqual([e2]); }); }); describe('getIncomers', () => { it('should return the nodes that have an edge to the provided node', () => { - expect(g.getIncomers(n3.id)).toEqual([n1, n2]); + expect(g.getIncomers(n3)).toEqual([n1, n2]); }); }); describe('getOutgoers', () => { it('should return the nodes that the provided node has an edge to', () => { - expect(g.getOutgoers(n3.id)).toEqual([n4, n5]); + expect(g.getOutgoers(n3)).toEqual([n4, n5]); }); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts index ecade47b23..e25cbaa78d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts @@ -3,31 +3,26 @@ import type { AnyInvocation, AnyInvocationInputField, AnyInvocationOutputField, + InputFields, Invocation, - InvocationInputFields, - InvocationOutputFields, InvocationType, - S, + OutputFields, } from 'services/api/types'; -import type { O } from 'ts-toolbelt'; import { assert } from 'tsafe'; import { v4 as uuidv4 } from 'uuid'; -type GraphType = O.NonNullable>; -type Edge = GraphType['edges'][number]; -type Never = Record; +type Edge = { + source: { + node_id: string; + field: AnyInvocationOutputField; + }; + destination: { + node_id: string; + field: AnyInvocationInputField; + }; +}; -// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils -// to preview their types from being widened from a union of valid field names to `string | number | symbol`. -type EdgeNodeType = Exclude; - -type EdgeFromField = TFrom extends EdgeNodeType - ? InvocationOutputFields - : AnyInvocationOutputField; - -type EdgeToField = TTo extends EdgeNodeType - ? InvocationInputFields - : AnyInvocationInputField; +type GraphType = { id: string; nodes: Record; edges: Edge[] }; export class Graph { _graph: GraphType; @@ -64,45 +59,12 @@ export class Graph { /** * Gets a node from the graph. * @param id The id of the node to get. - * @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type. * @returns The node. * @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type. */ - getNode(id: string, type?: T): Invocation { + getNode(id: string): AnyInvocation { const node = this._graph.nodes[id]; assert(node !== undefined, Graph.getNodeNotFoundMsg(id)); - if (type) { - assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type)); - } - // We just asserted that the node type is correct, this is OK to cast - return node as Invocation; - } - - /** - * Gets a node from the graph without raising an error if the node does not exist or is not of the expected type. - * @param id The id of the node to get. - * @param type The type of the node to get. If provided, node is guaranteed to be of this type. - * @returns The node, if it exists and is of the correct type. Otherwise, `undefined`. - */ - getNodeSafe(id: string, type?: T): Invocation | undefined { - try { - return this.getNode(id, type); - } catch { - return undefined; - } - } - - /** - * Update a node in the graph. Properties are shallow-copied from `updates` to the node. - * @param id The id of the node to update. - * @param type The type of the node to update. If provided, node is guaranteed to be of this type. - * @param updates The fields to update on the node. - * @returns The updated node. - * @raises `AssertionError` if the node does not exist or its type doesn't match. - */ - updateNode(id: string, type: T, updates: Partial>): Invocation { - const node = this.getNode(id, type); - Object.assign(node, updates); return node; } @@ -125,8 +87,8 @@ export class Graph { * @returns The incoming nodes. * @raises `AssertionError` if the node does not exist. */ - getIncomers(nodeId: string): AnyInvocation[] { - return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id)); + getIncomers(node: AnyInvocation): AnyInvocation[] { + return this.getEdgesTo(node).map((edge) => this.getNode(edge.source.node_id)); } /** @@ -135,8 +97,8 @@ export class Graph { * @returns The outgoing nodes. * @raises `AssertionError` if the node does not exist. */ - getOutgoers(nodeId: string): AnyInvocation[] { - return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id)); + getOutgoers(node: AnyInvocation): AnyInvocation[] { + return this.getEdgesFrom(node).map((edge) => this.getNode(edge.destination.node_id)); } //#endregion @@ -144,28 +106,26 @@ export class Graph { /** * Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised. - * Provide the from and to node types as generics to get type hints for from and to field names. - * @param fromNodeId The id of the source node. + * If providing node ids, provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNode The source node or id of the source node. * @param fromField The field of the source node. - * @param toNodeId The id of the destination node. + * @param toNode The source node or id of the destination node. * @param toField The field of the destination node. * @returns The added edge. * @raises `AssertionError` if an edge with the same source and destination already exists. */ - addEdge( - fromNodeId: string, - fromField: EdgeFromField, - toNodeId: string, - toField: EdgeToField + addEdge( + fromNode: TFrom, + fromField: OutputFields, + toNode: TTo, + toField: InputFields ): Edge { - const edge = { - source: { node_id: fromNodeId, field: fromField }, - destination: { node_id: toNodeId, field: toField }, + const edge: Edge = { + source: { node_id: fromNode.id, field: fromField }, + destination: { node_id: toNode.id, field: toField }, }; - assert( - !this._graph.edges.some((e) => isEqual(e, edge)), - Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField) - ); + const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge)); + assert(!edgeAlreadyExists, Graph.getEdgeAlreadyExistsMsg(fromNode.id, fromField, toNode.id, toField)); this._graph.edges.push(edge); return edge; } @@ -180,45 +140,23 @@ export class Graph { * @returns The edge. * @raises `AssertionError` if the edge does not exist. */ - getEdge( - fromNode: string, - fromField: EdgeFromField, - toNode: string, - toField: EdgeToField + getEdge( + fromNode: TFrom, + fromField: OutputFields, + toNode: TTo, + toField: InputFields ): Edge { const edge = this._graph.edges.find( (e) => - e.source.node_id === fromNode && + e.source.node_id === fromNode.id && e.source.field === fromField && - e.destination.node_id === toNode && + e.destination.node_id === toNode.id && e.destination.field === toField ); - assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField)); + assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode.id, fromField, toNode.id, toField)); return edge; } - /** - * Get an edge from the graph, or undefined if it doesn't exist. - * Provide the from and to node types as generics to get type hints for from and to field names. - * @param fromNodeId The id of the source node. - * @param fromField The field of the source node. - * @param toNodeId The id of the destination node. - * @param toField The field of the destination node. - * @returns The edge, or undefined if it doesn't exist. - */ - getEdgeSafe( - fromNode: string, - fromField: EdgeFromField, - toNode: string, - toField: EdgeToField - ): Edge | undefined { - try { - return this.getEdge(fromNode, fromField, toNode, toField); - } catch { - return undefined; - } - } - /** * Check if a graph has an edge. * Provide the from and to node types as generics to get type hints for from and to field names. @@ -229,11 +167,11 @@ export class Graph { * @returns Whether the graph has the edge. */ - hasEdge( - fromNode: string, - fromField: EdgeFromField, - toNode: string, - toField: EdgeToField + hasEdge( + fromNode: TFrom, + fromField: OutputFields, + toNode: TTo, + toField: InputFields ): boolean { try { this.getEdge(fromNode, fromField, toNode, toField); @@ -250,8 +188,8 @@ export class Graph { * @param fromField The field of the source node (optional). * @returns The edges. */ - getEdgesFrom(fromNodeId: string, fromField?: EdgeFromField): Edge[] { - let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId); + getEdgesFrom(fromNode: T, fromField?: OutputFields): Edge[] { + let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id); if (fromField) { edges = edges.filter((edge) => edge.source.field === fromField); } @@ -265,8 +203,8 @@ export class Graph { * @param toField The field of the destination node (optional). * @returns The edges. */ - getEdgesTo(toNodeId: string, toField?: EdgeToField): Edge[] { - let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId); + getEdgesTo(toNode: T, toField?: InputFields): Edge[] { + let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id); if (toField) { edges = edges.filter((edge) => edge.destination.field === toField); } @@ -284,11 +222,11 @@ export class Graph { /** * Delete all edges to a node. If `toField` is provided, only edges to that field are deleted. * Provide the to node type as a generic to get type hints for to field names. - * @param toNodeId The id of the destination node. + * @param toNode The destination node. * @param toField The field of the destination node (optional). */ - deleteEdgesTo(toNodeId: string, toField?: EdgeToField): void { - for (const edge of this.getEdgesTo(toNodeId, toField)) { + deleteEdgesTo(toNode: T, toField?: InputFields): void { + for (const edge of this.getEdgesTo(toNode, toField)) { this._deleteEdge(edge); } } @@ -299,8 +237,8 @@ export class Graph { * @param toNodeId The id of the source node. * @param toField The field of the source node (optional). */ - deleteEdgesFrom(fromNodeId: string, fromField?: EdgeFromField): void { - for (const edge of this.getEdgesFrom(fromNodeId, fromField)) { + deleteEdgesFrom(fromNode: T, fromField?: OutputFields): void { + for (const edge of this.getEdgesFrom(fromNode, fromField)) { this._deleteEdge(edge); } } diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts index ba76e43632..69e3676641 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.test.ts @@ -10,6 +10,7 @@ describe('MetadataUtil', () => { describe('getNode', () => { it('should return the metadata node if one exists', () => { const g = new Graph(); + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' }); expect(MetadataUtil.getNode(g)).toEqual(metadataNode); }); @@ -56,14 +57,16 @@ describe('MetadataUtil', () => { it('should add an edge from from metadata to the receiving node', () => { const n = g.addNode({ id: 'my-node', type: 'img_resize' }); MetadataUtil.add(g, { foo: 'bar' }); - MetadataUtil.setMetadataReceivingNode(g, n.id); - expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n.id, 'metadata')).toBe(true); + MetadataUtil.setMetadataReceivingNode(g, n); + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing + expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true); }); it('should remove existing metadata edges', () => { const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' }); - MetadataUtil.setMetadataReceivingNode(g, n2.id); - expect(g.getIncomers(n2.id).length).toBe(1); - expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n2.id, 'metadata')).toBe(true); + MetadataUtil.setMetadataReceivingNode(g, n2); + expect(g.getIncomers(n2).length).toBe(1); + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing + expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts index a51cebd21e..38e57a5e65 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/MetadataUtil.ts @@ -1,34 +1,50 @@ import type { ModelIdentifierField } from 'features/nodes/types/common'; import { METADATA } from 'features/nodes/util/graph/constants'; import { isString, unset } from 'lodash-es'; -import type { AnyModelConfig, Invocation } from 'services/api/types'; +import type { + AnyInvocation, + AnyInvocationIncMetadata, + AnyModelConfig, + CoreMetadataInvocation, + S, +} from 'services/api/types'; +import { assert } from 'tsafe'; import type { Graph } from './Graph'; +const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation => + node.type === 'core_metadata'; + export class MetadataUtil { static metadataNodeId = METADATA; - static getNode(graph: Graph): Invocation<'core_metadata'> { - return graph.getNode(this.metadataNodeId, 'core_metadata'); + static getNode(g: Graph): CoreMetadataInvocation { + const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata; + assert(isCoreMetadata(node)); + return node; } - static add(graph: Graph, metadata: Partial>): Invocation<'core_metadata'> { - const metadataNode = graph.getNodeSafe(this.metadataNodeId, 'core_metadata'); - if (!metadataNode) { - return graph.addNode({ + static add(g: Graph, metadata: Partial): CoreMetadataInvocation { + try { + const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata; + assert(isCoreMetadata(node)); + Object.assign(node, metadata); + return node; + } catch { + const metadataNode: CoreMetadataInvocation = { id: this.metadataNodeId, type: 'core_metadata', ...metadata, - }); - } else { - return graph.updateNode(this.metadataNodeId, 'core_metadata', metadata); + }; + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing + return g.addNode(metadataNode); } } - static remove(graph: Graph, key: string): Invocation<'core_metadata'>; - static remove(graph: Graph, keys: string[]): Invocation<'core_metadata'>; - static remove(graph: Graph, keyOrKeys: string | string[]): Invocation<'core_metadata'> { - const metadataNode = this.getNode(graph); + static remove(g: Graph, key: string): CoreMetadataInvocation; + static remove(g: Graph, keys: string[]): CoreMetadataInvocation; + static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation { + const metadataNode = this.getNode(g); if (isString(keyOrKeys)) { unset(metadataNode, keyOrKeys); } else { @@ -39,10 +55,11 @@ export class MetadataUtil { return metadataNode; } - static setMetadataReceivingNode(graph: Graph, nodeId: string): void { - // We need to break the rules to update metadata - `addEdge` doesn't allow `core_metadata` as a node type - graph._graph.edges = graph._graph.edges.filter((edge) => edge.source.node_id !== this.metadataNodeId); - graph.addEdge(this.metadataNodeId, 'metadata', nodeId, 'metadata'); + static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void { + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing + g.deleteEdgesFrom(this.getNode(g)); + // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing + g.addEdge(this.getNode(g), 'metadata', node, 'metadata'); } static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField { diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 8d41fd6474..e22f73ed9e 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -131,30 +131,29 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO']; export type KeysOfUnion = T extends T ? keyof T : never; -export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache'; -export type NonOutputFields = 'type'; -export type AnyInvocation = Graph['nodes'][string]; -export type AnyInvocationExcludeCoreMetata = Exclude; -export type InvocationType = AnyInvocation['type']; -export type InvocationTypeExcludeCoreMetadata = Exclude; +export type AnyInvocation = Exclude< + Graph['nodes'][string], + S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation'] +>; +export type AnyInvocationIncMetadata = S['Graph']['nodes'][string]; +export type InvocationType = AnyInvocation['type']; export type InvocationOutputMap = S['InvocationOutputMap']; export type AnyInvocationOutput = InvocationOutputMap[InvocationType]; export type Invocation = Extract; -export type InvocationExcludeCoreMetadata = Extract< - AnyInvocation, - { type: T } ->; -export type InvocationInputFields = Exclude< - keyof Invocation, - NonInputFields ->; -export type AnyInvocationInputField = Exclude, NonInputFields>; - export type InvocationOutput = InvocationOutputMap[T]; -export type InvocationOutputFields = Exclude, NonOutputFields>; -export type AnyInvocationOutputField = Exclude, NonOutputFields>; + +export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache' | 'board' | 'metadata'; +export type AnyInvocationInputField = Exclude>, NonInputFields>; +export type InputFields = Extract; + +export type NonOutputFields = 'type'; +export type AnyInvocationOutputField = Exclude>, NonOutputFields>; +export type OutputFields = Extract< + keyof InvocationOutputMap[T['type']], + AnyInvocationOutputField +>; // General nodes export type CollectInvocation = Invocation<'collect'>;