diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts new file mode 100644 index 0000000000..6a38dbd218 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.test.ts @@ -0,0 +1,338 @@ +import { Graph } from 'features/nodes/util/graph/Graph'; +import type { Invocation } from 'services/api/types'; +import { assert, AssertionError, is } from 'tsafe'; +import { validate } from 'uuid'; +import { describe, expect, it } from 'vitest'; + +describe('Graph', () => { + describe('constructor', () => { + it('should create a new graph with the correct id', () => { + const g = new Graph('test-id'); + expect(g._graph.id).toBe('test-id'); + }); + it('should create a new graph with a uuid id if none is provided', () => { + const g = new Graph(); + expect(g._graph.id).not.toBeUndefined(); + expect(validate(g._graph.id)).toBeTruthy(); + }); + }); + + describe('addNode', () => { + const testNode = { + id: 'test-node', + type: 'add', + } as const; + it('should add a node to the graph', () => { + const g = new Graph(); + g.addNode(testNode); + expect(g._graph.nodes['test-node']).not.toBeUndefined(); + expect(g._graph.nodes['test-node']?.type).toBe('add'); + }); + it('should set is_intermediate to true if not provided', () => { + const g = new Graph(); + g.addNode(testNode); + expect(g._graph.nodes['test-node']?.is_intermediate).toBe(true); + }); + it('should not overwrite is_intermediate if provided', () => { + const g = new Graph(); + g.addNode({ + ...testNode, + is_intermediate: false, + }); + expect(g._graph.nodes['test-node']?.is_intermediate).toBe(false); + }); + it('should set use_cache to true if not provided', () => { + const g = new Graph(); + g.addNode(testNode); + expect(g._graph.nodes['test-node']?.use_cache).toBe(true); + }); + it('should not overwrite use_cache if provided', () => { + const g = new Graph(); + g.addNode({ + ...testNode, + use_cache: false, + }); + expect(g._graph.nodes['test-node']?.use_cache).toBe(false); + }); + it('should error if the node id is already in the graph', () => { + const g = new Graph(); + g.addNode(testNode); + expect(() => g.addNode(testNode)).toThrowError(AssertionError); + }); + it('should infer the types if provided', () => { + const g = new Graph(); + const node = g.addNode(testNode); + assert(is>(node)); + const g2 = new Graph(); + // @ts-expect-error The node object is an `add` type, but the generic is a `sub` type + g2.addNode<'sub'>(testNode); + }); + }); + + describe('updateNode', () => { + it('should update the node with the provided id', () => { + const g = new Graph(); + const node: Invocation<'add'> = { + id: 'test-node', + type: 'add', + a: 1, + }; + g.addNode(node); + const updatedNode = g.updateNode('test-node', 'add', { + a: 2, + }); + expect(g.getNode('test-node', 'add').a).toBe(2); + expect(node).toBe(updatedNode); + }); + it('should throw an error if the node is not found', () => { + expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError); + }); + it('should throw an error if the node is found but has the wrong type', () => { + const g = new Graph(); + g.addNode({ + id: 'test-node', + type: 'add', + a: 1, + }); + expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError); + }); + it('should infer types correctly when `type` is omitted', () => { + const g = new Graph(); + g.addNode({ + id: 'test-node', + type: 'add', + a: 1, + }); + const updatedNode = g.updateNode('test-node', 'add', { + a: 2, + }); + assert(is>(updatedNode)); + }); + it('should infer types correctly when `type` is provided', () => { + const g = new Graph(); + g.addNode({ + id: 'test-node', + type: 'add', + a: 1, + }); + const updatedNode = g.updateNode('test-node', 'add', { + a: 2, + }); + assert(is>(updatedNode)); + }); + }); + + describe('addEdge', () => { + it('should add an edge to the graph with the provided values', () => { + const g = new Graph(); + g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + expect(g._graph.edges.length).toBe(1); + expect(g._graph.edges[0]).toEqual({ + source: { node_id: 'from-node', field: 'value' }, + destination: { node_id: 'to-node', field: 'b' }, + }); + }); + it('should throw an error if the edge already exists', () => { + const g = new Graph(); + g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError); + }); + it('should infer field names', () => { + const g = new Graph(); + // @ts-expect-error The first field must be a valid output field of the first type arg + g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a'); + // @ts-expect-error The second field must be a valid input field of the second type arg + g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field'); + // @ts-expect-error The first field must be any valid output field + g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a'); + // @ts-expect-error The second field must be any valid input field + g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field'); + }); + }); + + describe('getNode', () => { + const g = new Graph(); + const node = g.addNode({ + id: 'test-node', + type: 'add', + }); + + it('should return the node with the provided id', () => { + const n = g.getNode('test-node'); + expect(n).toBe(node); + }); + it('should return the node with the provided id and type', () => { + const n = g.getNode('test-node', 'add'); + expect(n).toBe(node); + assert(is>(node)); + }); + it('should throw an error if the node is not found', () => { + expect(() => g.getNode('not-found')).toThrowError(AssertionError); + }); + it('should throw an error if the node is found but has the wrong type', () => { + expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError); + }); + }); + + describe('getNodeSafe', () => { + const g = new Graph(); + const node = g.addNode({ + id: 'test-node', + type: 'add', + }); + it('should return the node if it is found', () => { + expect(g.getNodeSafe('test-node')).toBe(node); + }); + it('should return the node if it is found with the provided type', () => { + expect(g.getNodeSafe('test-node')).toBe(node); + assert(is>(node)); + }); + it("should return undefined if the node isn't found", () => { + expect(g.getNodeSafe('not-found')).toBeUndefined(); + }); + it('should return undefined if the node is found but has the wrong type', () => { + expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined(); + }); + }); + + describe('hasNode', () => { + const g = new Graph(); + g.addNode({ + id: 'test-node', + type: 'add', + }); + + it('should return true if the node is in the graph', () => { + expect(g.hasNode('test-node')).toBe(true); + }); + it('should return false if the node is not in the graph', () => { + expect(g.hasNode('not-found')).toBe(false); + }); + }); + + describe('getEdge', () => { + const g = new Graph(); + g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + it('should return the edge with the provided values', () => { + expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({ + source: { node_id: 'from-node', field: 'value' }, + destination: { node_id: 'to-node', field: 'b' }, + }); + }); + it('should throw an error if the edge is not found', () => { + expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError); + }); + }); + + describe('getEdgeSafe', () => { + const g = new Graph(); + g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + it('should return the edge if it is found', () => { + expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({ + source: { node_id: 'from-node', field: 'value' }, + destination: { node_id: 'to-node', field: 'b' }, + }); + }); + it('should return undefined if the edge is not found', () => { + expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined(); + }); + }); + + describe('hasEdge', () => { + const g = new Graph(); + g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); + it('should return true if the edge is in the graph', () => { + expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true); + }); + it('should return false if the edge is not in the graph', () => { + expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false); + }); + }); + + describe('getGraph', () => { + it('should return the graph', () => { + const g = new Graph(); + expect(g.getGraph()).toBe(g._graph); + }); + it('should raise an error if the graph is invalid', () => { + const g = new Graph(); + g.addEdge('from-node', 'value', 'to-node', 'b'); + expect(() => g.getGraph()).toThrowError(AssertionError); + }); + }); + + describe('getGraphSafe', () => { + it('should return the graph even if it is invalid', () => { + const g = new Graph(); + g.addEdge('from-node', 'value', 'to-node', 'b'); + expect(g.getGraphSafe()).toBe(g._graph); + }); + }); + + describe('validate', () => { + it('should not throw an error if the graph is valid', () => { + const g = new Graph(); + expect(() => g.validate()).not.toThrow(); + }); + it('should throw an error if the graph is invalid', () => { + const g = new Graph(); + // edge from nowhere to nowhere + g.addEdge('from-node', 'value', 'to-node', 'b'); + expect(() => g.validate()).toThrowError(AssertionError); + }); + }); + + describe('traversal', () => { + const g = new Graph(); + const n1 = g.addNode({ + id: 'n1', + type: 'add', + }); + const n2 = g.addNode({ + id: 'n2', + type: 'alpha_mask_to_tensor', + }); + const n3 = g.addNode({ + id: 'n3', + type: 'add', + }); + const n4 = g.addNode({ + id: 'n4', + type: 'add', + }); + const n5 = g.addNode({ + id: 'n5', + type: 'add', + }); + const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a'); + const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b'); + const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a'); + const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b'); + describe('getEdgesFrom', () => { + it('should return the edges that start at the provided node', () => { + expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]); + }); + it('should return the edges that start at the provided node and have the provided field', () => { + expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]); + }); + }); + describe('getEdgesTo', () => { + it('should return the edges that end at the provided node', () => { + expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]); + }); + it('should return the edges that end at the provided node and have the provided field', () => { + expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]); + }); + }); + describe('getIncomers', () => { + it('should return the nodes that have an edge to the provided node', () => { + expect(g.getIncomers(n3.id)).toEqual([n1, n2]); + }); + }); + describe('getOutgoers', () => { + it('should return the nodes that the provided node has an edge to', () => { + expect(g.getOutgoers(n3.id)).toEqual([n4, n5]); + }); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts new file mode 100644 index 0000000000..ecade47b23 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts @@ -0,0 +1,366 @@ +import { isEqual } from 'lodash-es'; +import type { + AnyInvocation, + AnyInvocationInputField, + AnyInvocationOutputField, + Invocation, + InvocationInputFields, + InvocationOutputFields, + InvocationType, + S, +} from 'services/api/types'; +import type { O } from 'ts-toolbelt'; +import { assert } from 'tsafe'; +import { v4 as uuidv4 } from 'uuid'; + +type GraphType = O.NonNullable>; +type Edge = GraphType['edges'][number]; +type Never = Record; + +// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils +// to preview their types from being widened from a union of valid field names to `string | number | symbol`. +type EdgeNodeType = Exclude; + +type EdgeFromField = TFrom extends EdgeNodeType + ? InvocationOutputFields + : AnyInvocationOutputField; + +type EdgeToField = TTo extends EdgeNodeType + ? InvocationInputFields + : AnyInvocationInputField; + +export class Graph { + _graph: GraphType; + + constructor(id?: string) { + this._graph = { + id: id ?? Graph.uuid(), + nodes: {}, + edges: [], + }; + } + + //#region Node Operations + + /** + * Add a node to the graph. If a node with the same id already exists, an `AssertionError` is raised. + * The optional `is_intermediate` and `use_cache` fields are set to `true` and `true` respectively if not set on the node. + * @param node The node to add. + * @returns The added node. + * @raises `AssertionError` if a node with the same id already exists. + */ + addNode(node: Invocation): Invocation { + assert(this._graph.nodes[node.id] === undefined, Graph.getNodeAlreadyExistsMsg(node.id)); + if (node.is_intermediate === undefined) { + node.is_intermediate = true; + } + if (node.use_cache === undefined) { + node.use_cache = true; + } + this._graph.nodes[node.id] = node; + return node; + } + + /** + * Gets a node from the graph. + * @param id The id of the node to get. + * @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type. + * @returns The node. + * @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type. + */ + getNode(id: string, type?: T): Invocation { + const node = this._graph.nodes[id]; + assert(node !== undefined, Graph.getNodeNotFoundMsg(id)); + if (type) { + assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type)); + } + // We just asserted that the node type is correct, this is OK to cast + return node as Invocation; + } + + /** + * Gets a node from the graph without raising an error if the node does not exist or is not of the expected type. + * @param id The id of the node to get. + * @param type The type of the node to get. If provided, node is guaranteed to be of this type. + * @returns The node, if it exists and is of the correct type. Otherwise, `undefined`. + */ + getNodeSafe(id: string, type?: T): Invocation | undefined { + try { + return this.getNode(id, type); + } catch { + return undefined; + } + } + + /** + * Update a node in the graph. Properties are shallow-copied from `updates` to the node. + * @param id The id of the node to update. + * @param type The type of the node to update. If provided, node is guaranteed to be of this type. + * @param updates The fields to update on the node. + * @returns The updated node. + * @raises `AssertionError` if the node does not exist or its type doesn't match. + */ + updateNode(id: string, type: T, updates: Partial>): Invocation { + const node = this.getNode(id, type); + Object.assign(node, updates); + return node; + } + + /** + * Check if a node exists in the graph. + * @param id The id of the node to check. + */ + hasNode(id: string): boolean { + try { + this.getNode(id); + return true; + } catch { + return false; + } + } + + /** + * Get the immediate incomers of a node. + * @param nodeId The id of the node to get the incomers of. + * @returns The incoming nodes. + * @raises `AssertionError` if the node does not exist. + */ + getIncomers(nodeId: string): AnyInvocation[] { + return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id)); + } + + /** + * Get the immediate outgoers of a node. + * @param nodeId The id of the node to get the outgoers of. + * @returns The outgoing nodes. + * @raises `AssertionError` if the node does not exist. + */ + getOutgoers(nodeId: string): AnyInvocation[] { + return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id)); + } + //#endregion + + //#region Edge Operations + + /** + * Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised. + * Provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNodeId The id of the source node. + * @param fromField The field of the source node. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node. + * @returns The added edge. + * @raises `AssertionError` if an edge with the same source and destination already exists. + */ + addEdge( + fromNodeId: string, + fromField: EdgeFromField, + toNodeId: string, + toField: EdgeToField + ): Edge { + const edge = { + source: { node_id: fromNodeId, field: fromField }, + destination: { node_id: toNodeId, field: toField }, + }; + assert( + !this._graph.edges.some((e) => isEqual(e, edge)), + Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField) + ); + this._graph.edges.push(edge); + return edge; + } + + /** + * Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised. + * Provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNodeId The id of the source node. + * @param fromField The field of the source node. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node. + * @returns The edge. + * @raises `AssertionError` if the edge does not exist. + */ + getEdge( + fromNode: string, + fromField: EdgeFromField, + toNode: string, + toField: EdgeToField + ): Edge { + const edge = this._graph.edges.find( + (e) => + e.source.node_id === fromNode && + e.source.field === fromField && + e.destination.node_id === toNode && + e.destination.field === toField + ); + assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField)); + return edge; + } + + /** + * Get an edge from the graph, or undefined if it doesn't exist. + * Provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNodeId The id of the source node. + * @param fromField The field of the source node. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node. + * @returns The edge, or undefined if it doesn't exist. + */ + getEdgeSafe( + fromNode: string, + fromField: EdgeFromField, + toNode: string, + toField: EdgeToField + ): Edge | undefined { + try { + return this.getEdge(fromNode, fromField, toNode, toField); + } catch { + return undefined; + } + } + + /** + * Check if a graph has an edge. + * Provide the from and to node types as generics to get type hints for from and to field names. + * @param fromNodeId The id of the source node. + * @param fromField The field of the source node. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node. + * @returns Whether the graph has the edge. + */ + + hasEdge( + fromNode: string, + fromField: EdgeFromField, + toNode: string, + toField: EdgeToField + ): boolean { + try { + this.getEdge(fromNode, fromField, toNode, toField); + return true; + } catch { + return false; + } + } + + /** + * Get all edges from a node. If `fromField` is provided, only edges from that field are returned. + * Provide the from node type as a generic to get type hints for from field names. + * @param fromNodeId The id of the source node. + * @param fromField The field of the source node (optional). + * @returns The edges. + */ + getEdgesFrom(fromNodeId: string, fromField?: EdgeFromField): Edge[] { + let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId); + if (fromField) { + edges = edges.filter((edge) => edge.source.field === fromField); + } + return edges; + } + + /** + * Get all edges to a node. If `toField` is provided, only edges to that field are returned. + * Provide the to node type as a generic to get type hints for to field names. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node (optional). + * @returns The edges. + */ + getEdgesTo(toNodeId: string, toField?: EdgeToField): Edge[] { + let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId); + if (toField) { + edges = edges.filter((edge) => edge.destination.field === toField); + } + return edges; + } + + /** + * Delete _all_ matching edges from the graph. Uses _.isEqual for comparison. + * @param edge The edge to delete + */ + private _deleteEdge(edge: Edge): void { + this._graph.edges = this._graph.edges.filter((e) => !isEqual(e, edge)); + } + + /** + * Delete all edges to a node. If `toField` is provided, only edges to that field are deleted. + * Provide the to node type as a generic to get type hints for to field names. + * @param toNodeId The id of the destination node. + * @param toField The field of the destination node (optional). + */ + deleteEdgesTo(toNodeId: string, toField?: EdgeToField): void { + for (const edge of this.getEdgesTo(toNodeId, toField)) { + this._deleteEdge(edge); + } + } + + /** + * Delete all edges from a node. If `fromField` is provided, only edges from that field are deleted. + * Provide the from node type as a generic to get type hints for from field names. + * @param toNodeId The id of the source node. + * @param toField The field of the source node (optional). + */ + deleteEdgesFrom(fromNodeId: string, fromField?: EdgeFromField): void { + for (const edge of this.getEdgesFrom(fromNodeId, fromField)) { + this._deleteEdge(edge); + } + } + //#endregion + + //#region Graph Ops + + /** + * Validate the graph. Checks that all edges have valid source and destination nodes. + * TODO(psyche): Add more validation checks - cycles, valid invocation types, etc. + * @raises `AssertionError` if an edge has an invalid source or destination node. + */ + validate(): void { + for (const edge of this._graph.edges) { + this.getNode(edge.source.node_id); + this.getNode(edge.destination.node_id); + } + } + + /** + * Gets the graph after validating it. + * @returns The graph. + * @raises `AssertionError` if the graph is invalid. + */ + getGraph(): GraphType { + this.validate(); + return this._graph; + } + + /** + * Gets the graph without validating it. + * @returns The graph. + */ + getGraphSafe(): GraphType { + return this._graph; + } + //#endregion + + //#region Util + + static getNodeNotFoundMsg(id: string): string { + return `Node ${id} not found`; + } + + static getNodeNotOfTypeMsg(node: AnyInvocation, expectedType: InvocationType): string { + return `Node ${node.id} is not of type ${expectedType}: ${node.type}`; + } + + static getNodeAlreadyExistsMsg(id: string): string { + return `Node ${id} already exists`; + } + + static getEdgeNotFoundMsg(fromNodeId: string, fromField: string, toNodeId: string, toField: string) { + return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} not found`; + } + + static getEdgeAlreadyExistsMsg(fromNodeId: string, fromField: string, toNodeId: string, toField: string) { + return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`; + } + + static uuid = uuidv4; + //#endregion +}