mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): add stateful Graph class
This stateful class provides abstractions for building a graph. It exposes graph methods like adding and removing nodes and edges. The methods are documented, tested, and strongly typed.
This commit is contained in:
parent
e3289856c0
commit
9d685da759
@ -0,0 +1,338 @@
|
||||
import { Graph } from 'features/nodes/util/graph/Graph';
|
||||
import type { Invocation } from 'services/api/types';
|
||||
import { assert, AssertionError, is } from 'tsafe';
|
||||
import { validate } from 'uuid';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
describe('Graph', () => {
|
||||
describe('constructor', () => {
|
||||
it('should create a new graph with the correct id', () => {
|
||||
const g = new Graph('test-id');
|
||||
expect(g._graph.id).toBe('test-id');
|
||||
});
|
||||
it('should create a new graph with a uuid id if none is provided', () => {
|
||||
const g = new Graph();
|
||||
expect(g._graph.id).not.toBeUndefined();
|
||||
expect(validate(g._graph.id)).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
describe('addNode', () => {
|
||||
const testNode = {
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
} as const;
|
||||
it('should add a node to the graph', () => {
|
||||
const g = new Graph();
|
||||
g.addNode(testNode);
|
||||
expect(g._graph.nodes['test-node']).not.toBeUndefined();
|
||||
expect(g._graph.nodes['test-node']?.type).toBe('add');
|
||||
});
|
||||
it('should set is_intermediate to true if not provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode(testNode);
|
||||
expect(g._graph.nodes['test-node']?.is_intermediate).toBe(true);
|
||||
});
|
||||
it('should not overwrite is_intermediate if provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
...testNode,
|
||||
is_intermediate: false,
|
||||
});
|
||||
expect(g._graph.nodes['test-node']?.is_intermediate).toBe(false);
|
||||
});
|
||||
it('should set use_cache to true if not provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode(testNode);
|
||||
expect(g._graph.nodes['test-node']?.use_cache).toBe(true);
|
||||
});
|
||||
it('should not overwrite use_cache if provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
...testNode,
|
||||
use_cache: false,
|
||||
});
|
||||
expect(g._graph.nodes['test-node']?.use_cache).toBe(false);
|
||||
});
|
||||
it('should error if the node id is already in the graph', () => {
|
||||
const g = new Graph();
|
||||
g.addNode(testNode);
|
||||
expect(() => g.addNode(testNode)).toThrowError(AssertionError);
|
||||
});
|
||||
it('should infer the types if provided', () => {
|
||||
const g = new Graph();
|
||||
const node = g.addNode(testNode);
|
||||
assert(is<Invocation<'add'>>(node));
|
||||
const g2 = new Graph();
|
||||
// @ts-expect-error The node object is an `add` type, but the generic is a `sub` type
|
||||
g2.addNode<'sub'>(testNode);
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateNode', () => {
|
||||
it('should update the node with the provided id', () => {
|
||||
const g = new Graph();
|
||||
const node: Invocation<'add'> = {
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
};
|
||||
g.addNode(node);
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
expect(g.getNode('test-node', 'add').a).toBe(2);
|
||||
expect(node).toBe(updatedNode);
|
||||
});
|
||||
it('should throw an error if the node is not found', () => {
|
||||
expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError);
|
||||
});
|
||||
it('should throw an error if the node is found but has the wrong type', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError);
|
||||
});
|
||||
it('should infer types correctly when `type` is omitted', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
assert(is<Invocation<'add'>>(updatedNode));
|
||||
});
|
||||
it('should infer types correctly when `type` is provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
assert(is<Invocation<'add'>>(updatedNode));
|
||||
});
|
||||
});
|
||||
|
||||
describe('addEdge', () => {
|
||||
it('should add an edge to the graph with the provided values', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
expect(g._graph.edges.length).toBe(1);
|
||||
expect(g._graph.edges[0]).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
destination: { node_id: 'to-node', field: 'b' },
|
||||
});
|
||||
});
|
||||
it('should throw an error if the edge already exists', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError);
|
||||
});
|
||||
it('should infer field names', () => {
|
||||
const g = new Graph();
|
||||
// @ts-expect-error The first field must be a valid output field of the first type arg
|
||||
g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a');
|
||||
// @ts-expect-error The second field must be a valid input field of the second type arg
|
||||
g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field');
|
||||
// @ts-expect-error The first field must be any valid output field
|
||||
g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a');
|
||||
// @ts-expect-error The second field must be any valid input field
|
||||
g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getNode', () => {
|
||||
const g = new Graph();
|
||||
const node = g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
});
|
||||
|
||||
it('should return the node with the provided id', () => {
|
||||
const n = g.getNode('test-node');
|
||||
expect(n).toBe(node);
|
||||
});
|
||||
it('should return the node with the provided id and type', () => {
|
||||
const n = g.getNode('test-node', 'add');
|
||||
expect(n).toBe(node);
|
||||
assert(is<Invocation<'add'>>(node));
|
||||
});
|
||||
it('should throw an error if the node is not found', () => {
|
||||
expect(() => g.getNode('not-found')).toThrowError(AssertionError);
|
||||
});
|
||||
it('should throw an error if the node is found but has the wrong type', () => {
|
||||
expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getNodeSafe', () => {
|
||||
const g = new Graph();
|
||||
const node = g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
});
|
||||
it('should return the node if it is found', () => {
|
||||
expect(g.getNodeSafe('test-node')).toBe(node);
|
||||
});
|
||||
it('should return the node if it is found with the provided type', () => {
|
||||
expect(g.getNodeSafe('test-node')).toBe(node);
|
||||
assert(is<Invocation<'add'>>(node));
|
||||
});
|
||||
it("should return undefined if the node isn't found", () => {
|
||||
expect(g.getNodeSafe('not-found')).toBeUndefined();
|
||||
});
|
||||
it('should return undefined if the node is found but has the wrong type', () => {
|
||||
expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasNode', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
});
|
||||
|
||||
it('should return true if the node is in the graph', () => {
|
||||
expect(g.hasNode('test-node')).toBe(true);
|
||||
});
|
||||
it('should return false if the node is not in the graph', () => {
|
||||
expect(g.hasNode('not-found')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getEdge', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
it('should return the edge with the provided values', () => {
|
||||
expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
destination: { node_id: 'to-node', field: 'b' },
|
||||
});
|
||||
});
|
||||
it('should throw an error if the edge is not found', () => {
|
||||
expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getEdgeSafe', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
it('should return the edge if it is found', () => {
|
||||
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
destination: { node_id: 'to-node', field: 'b' },
|
||||
});
|
||||
});
|
||||
it('should return undefined if the edge is not found', () => {
|
||||
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasEdge', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
it('should return true if the edge is in the graph', () => {
|
||||
expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true);
|
||||
});
|
||||
it('should return false if the edge is not in the graph', () => {
|
||||
expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getGraph', () => {
|
||||
it('should return the graph', () => {
|
||||
const g = new Graph();
|
||||
expect(g.getGraph()).toBe(g._graph);
|
||||
});
|
||||
it('should raise an error if the graph is invalid', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
expect(() => g.getGraph()).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getGraphSafe', () => {
|
||||
it('should return the graph even if it is invalid', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
expect(g.getGraphSafe()).toBe(g._graph);
|
||||
});
|
||||
});
|
||||
|
||||
describe('validate', () => {
|
||||
it('should not throw an error if the graph is valid', () => {
|
||||
const g = new Graph();
|
||||
expect(() => g.validate()).not.toThrow();
|
||||
});
|
||||
it('should throw an error if the graph is invalid', () => {
|
||||
const g = new Graph();
|
||||
// edge from nowhere to nowhere
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
expect(() => g.validate()).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('traversal', () => {
|
||||
const g = new Graph();
|
||||
const n1 = g.addNode({
|
||||
id: 'n1',
|
||||
type: 'add',
|
||||
});
|
||||
const n2 = g.addNode({
|
||||
id: 'n2',
|
||||
type: 'alpha_mask_to_tensor',
|
||||
});
|
||||
const n3 = g.addNode({
|
||||
id: 'n3',
|
||||
type: 'add',
|
||||
});
|
||||
const n4 = g.addNode({
|
||||
id: 'n4',
|
||||
type: 'add',
|
||||
});
|
||||
const n5 = g.addNode({
|
||||
id: 'n5',
|
||||
type: 'add',
|
||||
});
|
||||
const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a');
|
||||
const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b');
|
||||
const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a');
|
||||
const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b');
|
||||
describe('getEdgesFrom', () => {
|
||||
it('should return the edges that start at the provided node', () => {
|
||||
expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]);
|
||||
});
|
||||
it('should return the edges that start at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
describe('getEdgesTo', () => {
|
||||
it('should return the edges that end at the provided node', () => {
|
||||
expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]);
|
||||
});
|
||||
it('should return the edges that end at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
describe('getIncomers', () => {
|
||||
it('should return the nodes that have an edge to the provided node', () => {
|
||||
expect(g.getIncomers(n3.id)).toEqual([n1, n2]);
|
||||
});
|
||||
});
|
||||
describe('getOutgoers', () => {
|
||||
it('should return the nodes that the provided node has an edge to', () => {
|
||||
expect(g.getOutgoers(n3.id)).toEqual([n4, n5]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
366
invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts
Normal file
366
invokeai/frontend/web/src/features/nodes/util/graph/Graph.ts
Normal file
@ -0,0 +1,366 @@
|
||||
import { isEqual } from 'lodash-es';
|
||||
import type {
|
||||
AnyInvocation,
|
||||
AnyInvocationInputField,
|
||||
AnyInvocationOutputField,
|
||||
Invocation,
|
||||
InvocationInputFields,
|
||||
InvocationOutputFields,
|
||||
InvocationType,
|
||||
S,
|
||||
} from 'services/api/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type GraphType = O.NonNullable<O.Required<S['Graph']>>;
|
||||
type Edge = GraphType['edges'][number];
|
||||
type Never = Record<string, never>;
|
||||
|
||||
// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils
|
||||
// to preview their types from being widened from a union of valid field names to `string | number | symbol`.
|
||||
type EdgeNodeType = Exclude<InvocationType, 'core_metadata'>;
|
||||
|
||||
type EdgeFromField<TFrom extends EdgeNodeType | Never = Never> = TFrom extends EdgeNodeType
|
||||
? InvocationOutputFields<TFrom>
|
||||
: AnyInvocationOutputField;
|
||||
|
||||
type EdgeToField<TTo extends EdgeNodeType | Never = Never> = TTo extends EdgeNodeType
|
||||
? InvocationInputFields<TTo>
|
||||
: AnyInvocationInputField;
|
||||
|
||||
export class Graph {
|
||||
_graph: GraphType;
|
||||
|
||||
constructor(id?: string) {
|
||||
this._graph = {
|
||||
id: id ?? Graph.uuid(),
|
||||
nodes: {},
|
||||
edges: [],
|
||||
};
|
||||
}
|
||||
|
||||
//#region Node Operations
|
||||
|
||||
/**
|
||||
* Add a node to the graph. If a node with the same id already exists, an `AssertionError` is raised.
|
||||
* The optional `is_intermediate` and `use_cache` fields are set to `true` and `true` respectively if not set on the node.
|
||||
* @param node The node to add.
|
||||
* @returns The added node.
|
||||
* @raises `AssertionError` if a node with the same id already exists.
|
||||
*/
|
||||
addNode<T extends InvocationType>(node: Invocation<T>): Invocation<T> {
|
||||
assert(this._graph.nodes[node.id] === undefined, Graph.getNodeAlreadyExistsMsg(node.id));
|
||||
if (node.is_intermediate === undefined) {
|
||||
node.is_intermediate = true;
|
||||
}
|
||||
if (node.use_cache === undefined) {
|
||||
node.use_cache = true;
|
||||
}
|
||||
this._graph.nodes[node.id] = node;
|
||||
return node;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a node from the graph.
|
||||
* @param id The id of the node to get.
|
||||
* @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type.
|
||||
* @returns The node.
|
||||
* @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type.
|
||||
*/
|
||||
getNode<T extends InvocationType>(id: string, type?: T): Invocation<T> {
|
||||
const node = this._graph.nodes[id];
|
||||
assert(node !== undefined, Graph.getNodeNotFoundMsg(id));
|
||||
if (type) {
|
||||
assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type));
|
||||
}
|
||||
// We just asserted that the node type is correct, this is OK to cast
|
||||
return node as Invocation<T>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a node from the graph without raising an error if the node does not exist or is not of the expected type.
|
||||
* @param id The id of the node to get.
|
||||
* @param type The type of the node to get. If provided, node is guaranteed to be of this type.
|
||||
* @returns The node, if it exists and is of the correct type. Otherwise, `undefined`.
|
||||
*/
|
||||
getNodeSafe<T extends InvocationType>(id: string, type?: T): Invocation<T> | undefined {
|
||||
try {
|
||||
return this.getNode(id, type);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a node in the graph. Properties are shallow-copied from `updates` to the node.
|
||||
* @param id The id of the node to update.
|
||||
* @param type The type of the node to update. If provided, node is guaranteed to be of this type.
|
||||
* @param updates The fields to update on the node.
|
||||
* @returns The updated node.
|
||||
* @raises `AssertionError` if the node does not exist or its type doesn't match.
|
||||
*/
|
||||
updateNode<T extends InvocationType>(id: string, type: T, updates: Partial<Invocation<T>>): Invocation<T> {
|
||||
const node = this.getNode(id, type);
|
||||
Object.assign(node, updates);
|
||||
return node;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a node exists in the graph.
|
||||
* @param id The id of the node to check.
|
||||
*/
|
||||
hasNode(id: string): boolean {
|
||||
try {
|
||||
this.getNode(id);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the immediate incomers of a node.
|
||||
* @param nodeId The id of the node to get the incomers of.
|
||||
* @returns The incoming nodes.
|
||||
* @raises `AssertionError` if the node does not exist.
|
||||
*/
|
||||
getIncomers(nodeId: string): AnyInvocation[] {
|
||||
return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the immediate outgoers of a node.
|
||||
* @param nodeId The id of the node to get the outgoers of.
|
||||
* @returns The outgoing nodes.
|
||||
* @raises `AssertionError` if the node does not exist.
|
||||
*/
|
||||
getOutgoers(nodeId: string): AnyInvocation[] {
|
||||
return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id));
|
||||
}
|
||||
//#endregion
|
||||
|
||||
//#region Edge Operations
|
||||
|
||||
/**
|
||||
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The added edge.
|
||||
* @raises `AssertionError` if an edge with the same source and destination already exists.
|
||||
*/
|
||||
addEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNodeId: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNodeId: string,
|
||||
toField: EdgeToField<TTo>
|
||||
): Edge {
|
||||
const edge = {
|
||||
source: { node_id: fromNodeId, field: fromField },
|
||||
destination: { node_id: toNodeId, field: toField },
|
||||
};
|
||||
assert(
|
||||
!this._graph.edges.some((e) => isEqual(e, edge)),
|
||||
Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField)
|
||||
);
|
||||
this._graph.edges.push(edge);
|
||||
return edge;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an edge from the graph. If the edge does not exist, an `AssertionError` is raised.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The edge.
|
||||
* @raises `AssertionError` if the edge does not exist.
|
||||
*/
|
||||
getEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
): Edge {
|
||||
const edge = this._graph.edges.find(
|
||||
(e) =>
|
||||
e.source.node_id === fromNode &&
|
||||
e.source.field === fromField &&
|
||||
e.destination.node_id === toNode &&
|
||||
e.destination.field === toField
|
||||
);
|
||||
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField));
|
||||
return edge;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an edge from the graph, or undefined if it doesn't exist.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The edge, or undefined if it doesn't exist.
|
||||
*/
|
||||
getEdgeSafe<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
): Edge | undefined {
|
||||
try {
|
||||
return this.getEdge(fromNode, fromField, toNode, toField);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a graph has an edge.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns Whether the graph has the edge.
|
||||
*/
|
||||
|
||||
hasEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
): boolean {
|
||||
try {
|
||||
this.getEdge(fromNode, fromField, toNode, toField);
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all edges from a node. If `fromField` is provided, only edges from that field are returned.
|
||||
* Provide the from node type as a generic to get type hints for from field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId);
|
||||
if (fromField) {
|
||||
edges = edges.filter((edge) => edge.source.field === fromField);
|
||||
}
|
||||
return edges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all edges to a node. If `toField` is provided, only edges to that field are returned.
|
||||
* Provide the to node type as a generic to get type hints for to field names.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId);
|
||||
if (toField) {
|
||||
edges = edges.filter((edge) => edge.destination.field === toField);
|
||||
}
|
||||
return edges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete _all_ matching edges from the graph. Uses _.isEqual for comparison.
|
||||
* @param edge The edge to delete
|
||||
*/
|
||||
private _deleteEdge(edge: Edge): void {
|
||||
this._graph.edges = this._graph.edges.filter((e) => !isEqual(e, edge));
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
|
||||
* Provide the to node type as a generic to get type hints for to field names.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node (optional).
|
||||
*/
|
||||
deleteEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): void {
|
||||
for (const edge of this.getEdgesTo<TTo>(toNodeId, toField)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete all edges from a node. If `fromField` is provided, only edges from that field are deleted.
|
||||
* Provide the from node type as a generic to get type hints for from field names.
|
||||
* @param toNodeId The id of the source node.
|
||||
* @param toField The field of the source node (optional).
|
||||
*/
|
||||
deleteEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): void {
|
||||
for (const edge of this.getEdgesFrom<TFrom>(fromNodeId, fromField)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
//#endregion
|
||||
|
||||
//#region Graph Ops
|
||||
|
||||
/**
|
||||
* Validate the graph. Checks that all edges have valid source and destination nodes.
|
||||
* TODO(psyche): Add more validation checks - cycles, valid invocation types, etc.
|
||||
* @raises `AssertionError` if an edge has an invalid source or destination node.
|
||||
*/
|
||||
validate(): void {
|
||||
for (const edge of this._graph.edges) {
|
||||
this.getNode(edge.source.node_id);
|
||||
this.getNode(edge.destination.node_id);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the graph after validating it.
|
||||
* @returns The graph.
|
||||
* @raises `AssertionError` if the graph is invalid.
|
||||
*/
|
||||
getGraph(): GraphType {
|
||||
this.validate();
|
||||
return this._graph;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets the graph without validating it.
|
||||
* @returns The graph.
|
||||
*/
|
||||
getGraphSafe(): GraphType {
|
||||
return this._graph;
|
||||
}
|
||||
//#endregion
|
||||
|
||||
//#region Util
|
||||
|
||||
static getNodeNotFoundMsg(id: string): string {
|
||||
return `Node ${id} not found`;
|
||||
}
|
||||
|
||||
static getNodeNotOfTypeMsg(node: AnyInvocation, expectedType: InvocationType): string {
|
||||
return `Node ${node.id} is not of type ${expectedType}: ${node.type}`;
|
||||
}
|
||||
|
||||
static getNodeAlreadyExistsMsg(id: string): string {
|
||||
return `Node ${id} already exists`;
|
||||
}
|
||||
|
||||
static getEdgeNotFoundMsg(fromNodeId: string, fromField: string, toNodeId: string, toField: string) {
|
||||
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} not found`;
|
||||
}
|
||||
|
||||
static getEdgeAlreadyExistsMsg(fromNodeId: string, fromField: string, toNodeId: string, toField: string) {
|
||||
return `Edge from ${fromNodeId}.${fromField} to ${toNodeId}.${toField} already exists`;
|
||||
}
|
||||
|
||||
static uuid = uuidv4;
|
||||
//#endregion
|
||||
}
|
Loading…
Reference in New Issue
Block a user