feat(ui): refine graph building util

Simpler types and API surface.
This commit is contained in:
psychedelicious 2024-05-05 16:14:54 +10:00
parent 4020bf47e2
commit 8f6078d007
5 changed files with 189 additions and 278 deletions

View File

@ -69,63 +69,20 @@ describe('Graph', () => {
}); });
}); });
describe('updateNode', () => {
it('should update the node with the provided id', () => {
const g = new Graph();
const node: Invocation<'add'> = {
id: 'test-node',
type: 'add',
a: 1,
};
g.addNode(node);
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
expect(g.getNode('test-node', 'add').a).toBe(2);
expect(node).toBe(updatedNode);
});
it('should throw an error if the node is not found', () => {
expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError);
});
it('should throw an error if the node is found but has the wrong type', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError);
});
it('should infer types correctly when `type` is omitted', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
assert(is<Invocation<'add'>>(updatedNode));
});
it('should infer types correctly when `type` is provided', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
assert(is<Invocation<'add'>>(updatedNode));
});
});
describe('addEdge', () => { describe('addEdge', () => {
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
it('should add an edge to the graph with the provided values', () => { it('should add an edge to the graph with the provided values', () => {
const g = new Graph(); const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); g.addNode(add);
g.addNode(sub);
g.addEdge(add, 'value', sub, 'b');
expect(g._graph.edges.length).toBe(1); expect(g._graph.edges.length).toBe(1);
expect(g._graph.edges[0]).toEqual({ expect(g._graph.edges[0]).toEqual({
source: { node_id: 'from-node', field: 'value' }, source: { node_id: 'from-node', field: 'value' },
@ -134,19 +91,19 @@ describe('Graph', () => {
}); });
it('should throw an error if the edge already exists', () => { it('should throw an error if the edge already exists', () => {
const g = new Graph(); const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); g.addEdge(add, 'value', sub, 'b');
expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError); expect(() => g.addEdge(add, 'value', sub, 'b')).toThrowError(AssertionError);
}); });
it('should infer field names', () => { it('should infer field names', () => {
const g = new Graph(); const g = new Graph();
// @ts-expect-error The first field must be a valid output field of the first type arg // @ts-expect-error The first field must be a valid output field of the first type arg
g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a'); g.addEdge(add, 'not-a-valid-field', add, 'a');
// @ts-expect-error The second field must be a valid input field of the second type arg // @ts-expect-error The second field must be a valid input field of the second type arg
g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field'); g.addEdge(add, 'value', sub, 'not-a-valid-field');
// @ts-expect-error The first field must be any valid output field // @ts-expect-error The first field must be any valid output field
g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a'); g.addEdge(add, 'not-a-valid-field', sub, 'a');
// @ts-expect-error The second field must be any valid input field // @ts-expect-error The second field must be any valid input field
g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field'); g.addEdge(add, 'clip', sub, 'not-a-valid-field');
}); });
}); });
@ -161,38 +118,9 @@ describe('Graph', () => {
const n = g.getNode('test-node'); const n = g.getNode('test-node');
expect(n).toBe(node); expect(n).toBe(node);
}); });
it('should return the node with the provided id and type', () => {
const n = g.getNode('test-node', 'add');
expect(n).toBe(node);
assert(is<Invocation<'add'>>(node));
});
it('should throw an error if the node is not found', () => { it('should throw an error if the node is not found', () => {
expect(() => g.getNode('not-found')).toThrowError(AssertionError); expect(() => g.getNode('not-found')).toThrowError(AssertionError);
}); });
it('should throw an error if the node is found but has the wrong type', () => {
expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError);
});
});
describe('getNodeSafe', () => {
const g = new Graph();
const node = g.addNode({
id: 'test-node',
type: 'add',
});
it('should return the node if it is found', () => {
expect(g.getNodeSafe('test-node')).toBe(node);
});
it('should return the node if it is found with the provided type', () => {
expect(g.getNodeSafe('test-node')).toBe(node);
assert(is<Invocation<'add'>>(node));
});
it("should return undefined if the node isn't found", () => {
expect(g.getNodeSafe('not-found')).toBeUndefined();
});
it('should return undefined if the node is found but has the wrong type', () => {
expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined();
});
}); });
describe('hasNode', () => { describe('hasNode', () => {
@ -212,40 +140,42 @@ describe('Graph', () => {
describe('getEdge', () => { describe('getEdge', () => {
const g = new Graph(); const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
it('should return the edge with the provided values', () => { it('should return the edge with the provided values', () => {
expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({ expect(g.getEdge(add, 'value', sub, 'b')).toEqual({
source: { node_id: 'from-node', field: 'value' }, source: { node_id: 'from-node', field: 'value' },
destination: { node_id: 'to-node', field: 'b' }, destination: { node_id: 'to-node', field: 'b' },
}); });
}); });
it('should throw an error if the edge is not found', () => { it('should throw an error if the edge is not found', () => {
expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError); expect(() => g.getEdge(add, 'value', sub, 'a')).toThrowError(AssertionError);
});
});
describe('getEdgeSafe', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
it('should return the edge if it is found', () => {
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({
source: { node_id: 'from-node', field: 'value' },
destination: { node_id: 'to-node', field: 'b' },
});
});
it('should return undefined if the edge is not found', () => {
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined();
}); });
}); });
describe('hasEdge', () => { describe('hasEdge', () => {
const g = new Graph(); const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b'); const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
it('should return true if the edge is in the graph', () => { it('should return true if the edge is in the graph', () => {
expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true); expect(g.hasEdge(add, 'value', sub, 'b')).toBe(true);
}); });
it('should return false if the edge is not in the graph', () => { it('should return false if the edge is not in the graph', () => {
expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false); expect(g.hasEdge(add, 'value', sub, 'a')).toBe(false);
}); });
}); });
@ -256,7 +186,15 @@ describe('Graph', () => {
}); });
it('should raise an error if the graph is invalid', () => { it('should raise an error if the graph is invalid', () => {
const g = new Graph(); const g = new Graph();
g.addEdge('from-node', 'value', 'to-node', 'b'); const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
expect(() => g.getGraph()).toThrowError(AssertionError); expect(() => g.getGraph()).toThrowError(AssertionError);
}); });
}); });
@ -264,7 +202,15 @@ describe('Graph', () => {
describe('getGraphSafe', () => { describe('getGraphSafe', () => {
it('should return the graph even if it is invalid', () => { it('should return the graph even if it is invalid', () => {
const g = new Graph(); const g = new Graph();
g.addEdge('from-node', 'value', 'to-node', 'b'); const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
expect(g.getGraphSafe()).toBe(g._graph); expect(g.getGraphSafe()).toBe(g._graph);
}); });
}); });
@ -276,8 +222,16 @@ describe('Graph', () => {
}); });
it('should throw an error if the graph is invalid', () => { it('should throw an error if the graph is invalid', () => {
const g = new Graph(); const g = new Graph();
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
// edge from nowhere to nowhere // edge from nowhere to nowhere
g.addEdge('from-node', 'value', 'to-node', 'b'); g.addEdge(add, 'value', sub, 'b');
expect(() => g.validate()).toThrowError(AssertionError); expect(() => g.validate()).toThrowError(AssertionError);
}); });
}); });
@ -304,34 +258,34 @@ describe('Graph', () => {
id: 'n5', id: 'n5',
type: 'add', type: 'add',
}); });
const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a'); const e1 = g.addEdge(n1, 'value', n3, 'a');
const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b'); const e2 = g.addEdge(n2, 'height', n3, 'b');
const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a'); const e3 = g.addEdge(n3, 'value', n4, 'a');
const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b'); const e4 = g.addEdge(n3, 'value', n5, 'b');
describe('getEdgesFrom', () => { describe('getEdgesFrom', () => {
it('should return the edges that start at the provided node', () => { it('should return the edges that start at the provided node', () => {
expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]); expect(g.getEdgesFrom(n3)).toEqual([e3, e4]);
}); });
it('should return the edges that start at the provided node and have the provided field', () => { it('should return the edges that start at the provided node and have the provided field', () => {
expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]); expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]);
}); });
}); });
describe('getEdgesTo', () => { describe('getEdgesTo', () => {
it('should return the edges that end at the provided node', () => { it('should return the edges that end at the provided node', () => {
expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]); expect(g.getEdgesTo(n3)).toEqual([e1, e2]);
}); });
it('should return the edges that end at the provided node and have the provided field', () => { it('should return the edges that end at the provided node and have the provided field', () => {
expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]); expect(g.getEdgesTo(n3, 'b')).toEqual([e2]);
}); });
}); });
describe('getIncomers', () => { describe('getIncomers', () => {
it('should return the nodes that have an edge to the provided node', () => { it('should return the nodes that have an edge to the provided node', () => {
expect(g.getIncomers(n3.id)).toEqual([n1, n2]); expect(g.getIncomers(n3)).toEqual([n1, n2]);
}); });
}); });
describe('getOutgoers', () => { describe('getOutgoers', () => {
it('should return the nodes that the provided node has an edge to', () => { it('should return the nodes that the provided node has an edge to', () => {
expect(g.getOutgoers(n3.id)).toEqual([n4, n5]); expect(g.getOutgoers(n3)).toEqual([n4, n5]);
}); });
}); });
}); });

View File

@ -3,31 +3,26 @@ import type {
AnyInvocation, AnyInvocation,
AnyInvocationInputField, AnyInvocationInputField,
AnyInvocationOutputField, AnyInvocationOutputField,
InputFields,
Invocation, Invocation,
InvocationInputFields,
InvocationOutputFields,
InvocationType, InvocationType,
S, OutputFields,
} from 'services/api/types'; } from 'services/api/types';
import type { O } from 'ts-toolbelt';
import { assert } from 'tsafe'; import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid'; import { v4 as uuidv4 } from 'uuid';
type GraphType = O.NonNullable<O.Required<S['Graph']>>; type Edge = {
type Edge = GraphType['edges'][number]; source: {
type Never = Record<string, never>; node_id: string;
field: AnyInvocationOutputField;
};
destination: {
node_id: string;
field: AnyInvocationInputField;
};
};
// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
// to preview their types from being widened from a union of valid field names to `string | number | symbol`.
type EdgeNodeType = Exclude<InvocationType, 'core_metadata'>;
type EdgeFromField<TFrom extends EdgeNodeType | Never = Never> = TFrom extends EdgeNodeType
? InvocationOutputFields<TFrom>
: AnyInvocationOutputField;
type EdgeToField<TTo extends EdgeNodeType | Never = Never> = TTo extends EdgeNodeType
? InvocationInputFields<TTo>
: AnyInvocationInputField;
export class Graph { export class Graph {
_graph: GraphType; _graph: GraphType;
@ -64,45 +59,12 @@ export class Graph {
/** /**
* Gets a node from the graph. * Gets a node from the graph.
* @param id The id of the node to get. * @param id The id of the node to get.
* @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type.
* @returns The node. * @returns The node.
* @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type. * @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type.
*/ */
getNode<T extends InvocationType>(id: string, type?: T): Invocation<T> { getNode(id: string): AnyInvocation {
const node = this._graph.nodes[id]; const node = this._graph.nodes[id];
assert(node !== undefined, Graph.getNodeNotFoundMsg(id)); assert(node !== undefined, Graph.getNodeNotFoundMsg(id));
if (type) {
assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type));
}
// We just asserted that the node type is correct, this is OK to cast
return node as Invocation<T>;
}
/**
* Gets a node from the graph without raising an error if the node does not exist or is not of the expected type.
* @param id The id of the node to get.
* @param type The type of the node to get. If provided, node is guaranteed to be of this type.
* @returns The node, if it exists and is of the correct type. Otherwise, `undefined`.
*/
getNodeSafe<T extends InvocationType>(id: string, type?: T): Invocation<T> | undefined {
try {
return this.getNode(id, type);
} catch {
return undefined;
}
}
/**
* Update a node in the graph. Properties are shallow-copied from `updates` to the node.
* @param id The id of the node to update.
* @param type The type of the node to update. If provided, node is guaranteed to be of this type.
* @param updates The fields to update on the node.
* @returns The updated node.
* @raises `AssertionError` if the node does not exist or its type doesn't match.
*/
updateNode<T extends InvocationType>(id: string, type: T, updates: Partial<Invocation<T>>): Invocation<T> {
const node = this.getNode(id, type);
Object.assign(node, updates);
return node; return node;
} }
@ -125,8 +87,8 @@ export class Graph {
* @returns The incoming nodes. * @returns The incoming nodes.
* @raises `AssertionError` if the node does not exist. * @raises `AssertionError` if the node does not exist.
*/ */
getIncomers(nodeId: string): AnyInvocation[] { getIncomers(node: AnyInvocation): AnyInvocation[] {
return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id)); return this.getEdgesTo(node).map((edge) => this.getNode(edge.source.node_id));
} }
/** /**
@ -135,8 +97,8 @@ export class Graph {
* @returns The outgoing nodes. * @returns The outgoing nodes.
* @raises `AssertionError` if the node does not exist. * @raises `AssertionError` if the node does not exist.
*/ */
getOutgoers(nodeId: string): AnyInvocation[] { getOutgoers(node: AnyInvocation): AnyInvocation[] {
return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id)); return this.getEdgesFrom(node).map((edge) => this.getNode(edge.destination.node_id));
} }
//#endregion //#endregion
@ -144,28 +106,26 @@ export class Graph {
/** /**
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised. * Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
* Provide the from and to node types as generics to get type hints for from and to field names. * If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNodeId The id of the source node. * @param fromNode The source node or id of the source node.
* @param fromField The field of the source node. * @param fromField The field of the source node.
* @param toNodeId The id of the destination node. * @param toNode The source node or id of the destination node.
* @param toField The field of the destination node. * @param toField The field of the destination node.
* @returns The added edge. * @returns The added edge.
* @raises `AssertionError` if an edge with the same source and destination already exists. * @raises `AssertionError` if an edge with the same source and destination already exists.
*/ */
addEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>( addEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNodeId: string, fromNode: TFrom,
fromField: EdgeFromField<TFrom>, fromField: OutputFields<TFrom>,
toNodeId: string, toNode: TTo,
toField: EdgeToField<TTo> toField: InputFields<TTo>
): Edge { ): Edge {
const edge = { const edge: Edge = {
source: { node_id: fromNodeId, field: fromField }, source: { node_id: fromNode.id, field: fromField },
destination: { node_id: toNodeId, field: toField }, destination: { node_id: toNode.id, field: toField },
}; };
assert( const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
!this._graph.edges.some((e) => isEqual(e, edge)), assert(!edgeAlreadyExists, Graph.getEdgeAlreadyExistsMsg(fromNode.id, fromField, toNode.id, toField));
Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField)
);
this._graph.edges.push(edge); this._graph.edges.push(edge);
return edge; return edge;
} }
@ -180,45 +140,23 @@ export class Graph {
* @returns The edge. * @returns The edge.
* @raises `AssertionError` if the edge does not exist. * @raises `AssertionError` if the edge does not exist.
*/ */
getEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>( getEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNode: string, fromNode: TFrom,
fromField: EdgeFromField<TFrom>, fromField: OutputFields<TFrom>,
toNode: string, toNode: TTo,
toField: EdgeToField<TTo> toField: InputFields<TTo>
): Edge { ): Edge {
const edge = this._graph.edges.find( const edge = this._graph.edges.find(
(e) => (e) =>
e.source.node_id === fromNode && e.source.node_id === fromNode.id &&
e.source.field === fromField && e.source.field === fromField &&
e.destination.node_id === toNode && e.destination.node_id === toNode.id &&
e.destination.field === toField e.destination.field === toField
); );
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField)); assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode.id, fromField, toNode.id, toField));
return edge; return edge;
} }
/**
* Get an edge from the graph, or undefined if it doesn't exist.
* Provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNodeId The id of the source node.
* @param fromField The field of the source node.
* @param toNodeId The id of the destination node.
* @param toField The field of the destination node.
* @returns The edge, or undefined if it doesn't exist.
*/
getEdgeSafe<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
fromNode: string,
fromField: EdgeFromField<TFrom>,
toNode: string,
toField: EdgeToField<TTo>
): Edge | undefined {
try {
return this.getEdge(fromNode, fromField, toNode, toField);
} catch {
return undefined;
}
}
/** /**
* Check if a graph has an edge. * Check if a graph has an edge.
* Provide the from and to node types as generics to get type hints for from and to field names. * Provide the from and to node types as generics to get type hints for from and to field names.
@ -229,11 +167,11 @@ export class Graph {
* @returns Whether the graph has the edge. * @returns Whether the graph has the edge.
*/ */
hasEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>( hasEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNode: string, fromNode: TFrom,
fromField: EdgeFromField<TFrom>, fromField: OutputFields<TFrom>,
toNode: string, toNode: TTo,
toField: EdgeToField<TTo> toField: InputFields<TTo>
): boolean { ): boolean {
try { try {
this.getEdge(fromNode, fromField, toNode, toField); this.getEdge(fromNode, fromField, toNode, toField);
@ -250,8 +188,8 @@ export class Graph {
* @param fromField The field of the source node (optional). * @param fromField The field of the source node (optional).
* @returns The edges. * @returns The edges.
*/ */
getEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): Edge[] { getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId); let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id);
if (fromField) { if (fromField) {
edges = edges.filter((edge) => edge.source.field === fromField); edges = edges.filter((edge) => edge.source.field === fromField);
} }
@ -265,8 +203,8 @@ export class Graph {
* @param toField The field of the destination node (optional). * @param toField The field of the destination node (optional).
* @returns The edges. * @returns The edges.
*/ */
getEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): Edge[] { getEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId); let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id);
if (toField) { if (toField) {
edges = edges.filter((edge) => edge.destination.field === toField); edges = edges.filter((edge) => edge.destination.field === toField);
} }
@ -284,11 +222,11 @@ export class Graph {
/** /**
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted. * Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
* Provide the to node type as a generic to get type hints for to field names. * Provide the to node type as a generic to get type hints for to field names.
* @param toNodeId The id of the destination node. * @param toNode The destination node.
* @param toField The field of the destination node (optional). * @param toField The field of the destination node (optional).
*/ */
deleteEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): void { deleteEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): void {
for (const edge of this.getEdgesTo<TTo>(toNodeId, toField)) { for (const edge of this.getEdgesTo(toNode, toField)) {
this._deleteEdge(edge); this._deleteEdge(edge);
} }
} }
@ -299,8 +237,8 @@ export class Graph {
* @param toNodeId The id of the source node. * @param toNodeId The id of the source node.
* @param toField The field of the source node (optional). * @param toField The field of the source node (optional).
*/ */
deleteEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): void { deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): void {
for (const edge of this.getEdgesFrom<TFrom>(fromNodeId, fromField)) { for (const edge of this.getEdgesFrom(fromNode, fromField)) {
this._deleteEdge(edge); this._deleteEdge(edge);
} }
} }

View File

@ -10,6 +10,7 @@ describe('MetadataUtil', () => {
describe('getNode', () => { describe('getNode', () => {
it('should return the metadata node if one exists', () => { it('should return the metadata node if one exists', () => {
const g = new Graph(); const g = new Graph();
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' }); const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' });
expect(MetadataUtil.getNode(g)).toEqual(metadataNode); expect(MetadataUtil.getNode(g)).toEqual(metadataNode);
}); });
@ -56,14 +57,16 @@ describe('MetadataUtil', () => {
it('should add an edge from from metadata to the receiving node', () => { it('should add an edge from from metadata to the receiving node', () => {
const n = g.addNode({ id: 'my-node', type: 'img_resize' }); const n = g.addNode({ id: 'my-node', type: 'img_resize' });
MetadataUtil.add(g, { foo: 'bar' }); MetadataUtil.add(g, { foo: 'bar' });
MetadataUtil.setMetadataReceivingNode(g, n.id); MetadataUtil.setMetadataReceivingNode(g, n);
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n.id, 'metadata')).toBe(true); // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true);
}); });
it('should remove existing metadata edges', () => { it('should remove existing metadata edges', () => {
const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' }); const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' });
MetadataUtil.setMetadataReceivingNode(g, n2.id); MetadataUtil.setMetadataReceivingNode(g, n2);
expect(g.getIncomers(n2.id).length).toBe(1); expect(g.getIncomers(n2).length).toBe(1);
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n2.id, 'metadata')).toBe(true); // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true);
}); });
}); });

View File

@ -1,34 +1,50 @@
import type { ModelIdentifierField } from 'features/nodes/types/common'; import type { ModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants'; import { METADATA } from 'features/nodes/util/graph/constants';
import { isString, unset } from 'lodash-es'; import { isString, unset } from 'lodash-es';
import type { AnyModelConfig, Invocation } from 'services/api/types'; import type {
AnyInvocation,
AnyInvocationIncMetadata,
AnyModelConfig,
CoreMetadataInvocation,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import type { Graph } from './Graph'; import type { Graph } from './Graph';
const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation =>
node.type === 'core_metadata';
export class MetadataUtil { export class MetadataUtil {
static metadataNodeId = METADATA; static metadataNodeId = METADATA;
static getNode(graph: Graph): Invocation<'core_metadata'> { static getNode(g: Graph): CoreMetadataInvocation {
return graph.getNode(this.metadataNodeId, 'core_metadata'); const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
assert(isCoreMetadata(node));
return node;
} }
static add(graph: Graph, metadata: Partial<Invocation<'core_metadata'>>): Invocation<'core_metadata'> { static add(g: Graph, metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
const metadataNode = graph.getNodeSafe(this.metadataNodeId, 'core_metadata'); try {
if (!metadataNode) { const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
return graph.addNode({ assert(isCoreMetadata(node));
Object.assign(node, metadata);
return node;
} catch {
const metadataNode: CoreMetadataInvocation = {
id: this.metadataNodeId, id: this.metadataNodeId,
type: 'core_metadata', type: 'core_metadata',
...metadata, ...metadata,
}); };
} else { // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return graph.updateNode(this.metadataNodeId, 'core_metadata', metadata); return g.addNode(metadataNode);
} }
} }
static remove(graph: Graph, key: string): Invocation<'core_metadata'>; static remove(g: Graph, key: string): CoreMetadataInvocation;
static remove(graph: Graph, keys: string[]): Invocation<'core_metadata'>; static remove(g: Graph, keys: string[]): CoreMetadataInvocation;
static remove(graph: Graph, keyOrKeys: string | string[]): Invocation<'core_metadata'> { static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation {
const metadataNode = this.getNode(graph); const metadataNode = this.getNode(g);
if (isString(keyOrKeys)) { if (isString(keyOrKeys)) {
unset(metadataNode, keyOrKeys); unset(metadataNode, keyOrKeys);
} else { } else {
@ -39,10 +55,11 @@ export class MetadataUtil {
return metadataNode; return metadataNode;
} }
static setMetadataReceivingNode(graph: Graph, nodeId: string): void { static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void {
// We need to break the rules to update metadata - `addEdge` doesn't allow `core_metadata` as a node type // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
graph._graph.edges = graph._graph.edges.filter((edge) => edge.source.node_id !== this.metadataNodeId); g.deleteEdgesFrom(this.getNode(g));
graph.addEdge(this.metadataNodeId, 'metadata', nodeId, 'metadata'); // @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
g.addEdge(this.getNode(g), 'metadata', node, 'metadata');
} }
static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField { static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField {

View File

@ -131,30 +131,29 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
export type KeysOfUnion<T> = T extends T ? keyof T : never; export type KeysOfUnion<T> = T extends T ? keyof T : never;
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache'; export type AnyInvocation = Exclude<
export type NonOutputFields = 'type'; Graph['nodes'][string],
export type AnyInvocation = Graph['nodes'][string]; S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
export type AnyInvocationExcludeCoreMetata = Exclude<AnyInvocation, { type: 'core_metadata' }>; >;
export type InvocationType = AnyInvocation['type']; export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
export type InvocationTypeExcludeCoreMetadata = Exclude<InvocationType, 'core_metadata'>;
export type InvocationType = AnyInvocation['type'];
export type InvocationOutputMap = S['InvocationOutputMap']; export type InvocationOutputMap = S['InvocationOutputMap'];
export type AnyInvocationOutput = InvocationOutputMap[InvocationType]; export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>; export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
export type InvocationExcludeCoreMetadata<T extends InvocationTypeExcludeCoreMetadata> = Extract<
AnyInvocation,
{ type: T }
>;
export type InvocationInputFields<T extends InvocationTypeExcludeCoreMetadata> = Exclude<
keyof Invocation<T>,
NonInputFields
>;
export type AnyInvocationInputField = Exclude<KeysOfUnion<AnyInvocationExcludeCoreMetata>, NonInputFields>;
export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T]; export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
export type InvocationOutputFields<T extends InvocationType> = Exclude<keyof InvocationOutput<T>, NonOutputFields>;
export type AnyInvocationOutputField = Exclude<KeysOfUnion<AnyInvocationOutput>, NonOutputFields>; export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache' | 'board' | 'metadata';
export type AnyInvocationInputField = Exclude<KeysOfUnion<Required<AnyInvocation>>, NonInputFields>;
export type InputFields<T extends AnyInvocation> = Extract<keyof T, AnyInvocationInputField>;
export type NonOutputFields = 'type';
export type AnyInvocationOutputField = Exclude<KeysOfUnion<Required<AnyInvocationOutput>>, NonOutputFields>;
export type OutputFields<T extends AnyInvocation> = Extract<
keyof InvocationOutputMap[T['type']],
AnyInvocationOutputField
>;
// General nodes // General nodes
export type CollectInvocation = Invocation<'collect'>; export type CollectInvocation = Invocation<'collect'>;