feat(ui): refine graph building util

Simpler types and API surface.
This commit is contained in:
psychedelicious 2024-05-05 16:14:54 +10:00
parent 4020bf47e2
commit 8f6078d007
5 changed files with 189 additions and 278 deletions

View File

@ -69,63 +69,20 @@ describe('Graph', () => {
});
});
describe('updateNode', () => {
it('should update the node with the provided id', () => {
const g = new Graph();
const node: Invocation<'add'> = {
id: 'test-node',
type: 'add',
a: 1,
};
g.addNode(node);
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
expect(g.getNode('test-node', 'add').a).toBe(2);
expect(node).toBe(updatedNode);
});
it('should throw an error if the node is not found', () => {
expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError);
});
it('should throw an error if the node is found but has the wrong type', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError);
});
it('should infer types correctly when `type` is omitted', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
assert(is<Invocation<'add'>>(updatedNode));
});
it('should infer types correctly when `type` is provided', () => {
const g = new Graph();
g.addNode({
id: 'test-node',
type: 'add',
a: 1,
});
const updatedNode = g.updateNode('test-node', 'add', {
a: 2,
});
assert(is<Invocation<'add'>>(updatedNode));
});
});
describe('addEdge', () => {
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
it('should add an edge to the graph with the provided values', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
g.addNode(add);
g.addNode(sub);
g.addEdge(add, 'value', sub, 'b');
expect(g._graph.edges.length).toBe(1);
expect(g._graph.edges[0]).toEqual({
source: { node_id: 'from-node', field: 'value' },
@ -134,19 +91,19 @@ describe('Graph', () => {
});
it('should throw an error if the edge already exists', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError);
g.addEdge(add, 'value', sub, 'b');
expect(() => g.addEdge(add, 'value', sub, 'b')).toThrowError(AssertionError);
});
it('should infer field names', () => {
const g = new Graph();
// @ts-expect-error The first field must be a valid output field of the first type arg
g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a');
g.addEdge(add, 'not-a-valid-field', add, 'a');
// @ts-expect-error The second field must be a valid input field of the second type arg
g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field');
g.addEdge(add, 'value', sub, 'not-a-valid-field');
// @ts-expect-error The first field must be any valid output field
g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a');
g.addEdge(add, 'not-a-valid-field', sub, 'a');
// @ts-expect-error The second field must be any valid input field
g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field');
g.addEdge(add, 'clip', sub, 'not-a-valid-field');
});
});
@ -161,38 +118,9 @@ describe('Graph', () => {
const n = g.getNode('test-node');
expect(n).toBe(node);
});
it('should return the node with the provided id and type', () => {
const n = g.getNode('test-node', 'add');
expect(n).toBe(node);
assert(is<Invocation<'add'>>(node));
});
it('should throw an error if the node is not found', () => {
expect(() => g.getNode('not-found')).toThrowError(AssertionError);
});
it('should throw an error if the node is found but has the wrong type', () => {
expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError);
});
});
describe('getNodeSafe', () => {
const g = new Graph();
const node = g.addNode({
id: 'test-node',
type: 'add',
});
it('should return the node if it is found', () => {
expect(g.getNodeSafe('test-node')).toBe(node);
});
it('should return the node if it is found with the provided type', () => {
expect(g.getNodeSafe('test-node')).toBe(node);
assert(is<Invocation<'add'>>(node));
});
it("should return undefined if the node isn't found", () => {
expect(g.getNodeSafe('not-found')).toBeUndefined();
});
it('should return undefined if the node is found but has the wrong type', () => {
expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined();
});
});
describe('hasNode', () => {
@ -212,40 +140,42 @@ describe('Graph', () => {
describe('getEdge', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
it('should return the edge with the provided values', () => {
expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({
expect(g.getEdge(add, 'value', sub, 'b')).toEqual({
source: { node_id: 'from-node', field: 'value' },
destination: { node_id: 'to-node', field: 'b' },
});
});
it('should throw an error if the edge is not found', () => {
expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError);
});
});
describe('getEdgeSafe', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
it('should return the edge if it is found', () => {
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({
source: { node_id: 'from-node', field: 'value' },
destination: { node_id: 'to-node', field: 'b' },
});
});
it('should return undefined if the edge is not found', () => {
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined();
expect(() => g.getEdge(add, 'value', sub, 'a')).toThrowError(AssertionError);
});
});
describe('hasEdge', () => {
const g = new Graph();
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
it('should return true if the edge is in the graph', () => {
expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true);
expect(g.hasEdge(add, 'value', sub, 'b')).toBe(true);
});
it('should return false if the edge is not in the graph', () => {
expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false);
expect(g.hasEdge(add, 'value', sub, 'a')).toBe(false);
});
});
@ -256,7 +186,15 @@ describe('Graph', () => {
});
it('should raise an error if the graph is invalid', () => {
const g = new Graph();
g.addEdge('from-node', 'value', 'to-node', 'b');
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
expect(() => g.getGraph()).toThrowError(AssertionError);
});
});
@ -264,7 +202,15 @@ describe('Graph', () => {
describe('getGraphSafe', () => {
it('should return the graph even if it is invalid', () => {
const g = new Graph();
g.addEdge('from-node', 'value', 'to-node', 'b');
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
g.addEdge(add, 'value', sub, 'b');
expect(g.getGraphSafe()).toBe(g._graph);
});
});
@ -276,8 +222,16 @@ describe('Graph', () => {
});
it('should throw an error if the graph is invalid', () => {
const g = new Graph();
const add: Invocation<'add'> = {
id: 'from-node',
type: 'add',
};
const sub: Invocation<'sub'> = {
id: 'to-node',
type: 'sub',
};
// edge from nowhere to nowhere
g.addEdge('from-node', 'value', 'to-node', 'b');
g.addEdge(add, 'value', sub, 'b');
expect(() => g.validate()).toThrowError(AssertionError);
});
});
@ -304,34 +258,34 @@ describe('Graph', () => {
id: 'n5',
type: 'add',
});
const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a');
const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b');
const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a');
const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b');
const e1 = g.addEdge(n1, 'value', n3, 'a');
const e2 = g.addEdge(n2, 'height', n3, 'b');
const e3 = g.addEdge(n3, 'value', n4, 'a');
const e4 = g.addEdge(n3, 'value', n5, 'b');
describe('getEdgesFrom', () => {
it('should return the edges that start at the provided node', () => {
expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]);
expect(g.getEdgesFrom(n3)).toEqual([e3, e4]);
});
it('should return the edges that start at the provided node and have the provided field', () => {
expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]);
expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]);
});
});
describe('getEdgesTo', () => {
it('should return the edges that end at the provided node', () => {
expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]);
expect(g.getEdgesTo(n3)).toEqual([e1, e2]);
});
it('should return the edges that end at the provided node and have the provided field', () => {
expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]);
expect(g.getEdgesTo(n3, 'b')).toEqual([e2]);
});
});
describe('getIncomers', () => {
it('should return the nodes that have an edge to the provided node', () => {
expect(g.getIncomers(n3.id)).toEqual([n1, n2]);
expect(g.getIncomers(n3)).toEqual([n1, n2]);
});
});
describe('getOutgoers', () => {
it('should return the nodes that the provided node has an edge to', () => {
expect(g.getOutgoers(n3.id)).toEqual([n4, n5]);
expect(g.getOutgoers(n3)).toEqual([n4, n5]);
});
});
});

View File

@ -3,31 +3,26 @@ import type {
AnyInvocation,
AnyInvocationInputField,
AnyInvocationOutputField,
InputFields,
Invocation,
InvocationInputFields,
InvocationOutputFields,
InvocationType,
S,
OutputFields,
} from 'services/api/types';
import type { O } from 'ts-toolbelt';
import { assert } from 'tsafe';
import { v4 as uuidv4 } from 'uuid';
type GraphType = O.NonNullable<O.Required<S['Graph']>>;
type Edge = GraphType['edges'][number];
type Never = Record<string, never>;
type Edge = {
source: {
node_id: string;
field: AnyInvocationOutputField;
};
destination: {
node_id: string;
field: AnyInvocationInputField;
};
};
// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils
// to preview their types from being widened from a union of valid field names to `string | number | symbol`.
type EdgeNodeType = Exclude<InvocationType, 'core_metadata'>;
type EdgeFromField<TFrom extends EdgeNodeType | Never = Never> = TFrom extends EdgeNodeType
? InvocationOutputFields<TFrom>
: AnyInvocationOutputField;
type EdgeToField<TTo extends EdgeNodeType | Never = Never> = TTo extends EdgeNodeType
? InvocationInputFields<TTo>
: AnyInvocationInputField;
type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
export class Graph {
_graph: GraphType;
@ -64,45 +59,12 @@ export class Graph {
/**
* Gets a node from the graph.
* @param id The id of the node to get.
* @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type.
* @returns The node.
* @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type.
*/
getNode<T extends InvocationType>(id: string, type?: T): Invocation<T> {
getNode(id: string): AnyInvocation {
const node = this._graph.nodes[id];
assert(node !== undefined, Graph.getNodeNotFoundMsg(id));
if (type) {
assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type));
}
// We just asserted that the node type is correct, this is OK to cast
return node as Invocation<T>;
}
/**
* Gets a node from the graph without raising an error if the node does not exist or is not of the expected type.
* @param id The id of the node to get.
* @param type The type of the node to get. If provided, node is guaranteed to be of this type.
* @returns The node, if it exists and is of the correct type. Otherwise, `undefined`.
*/
getNodeSafe<T extends InvocationType>(id: string, type?: T): Invocation<T> | undefined {
try {
return this.getNode(id, type);
} catch {
return undefined;
}
}
/**
* Update a node in the graph. Properties are shallow-copied from `updates` to the node.
* @param id The id of the node to update.
* @param type The type of the node to update. If provided, node is guaranteed to be of this type.
* @param updates The fields to update on the node.
* @returns The updated node.
* @raises `AssertionError` if the node does not exist or its type doesn't match.
*/
updateNode<T extends InvocationType>(id: string, type: T, updates: Partial<Invocation<T>>): Invocation<T> {
const node = this.getNode(id, type);
Object.assign(node, updates);
return node;
}
@ -125,8 +87,8 @@ export class Graph {
* @returns The incoming nodes.
* @raises `AssertionError` if the node does not exist.
*/
getIncomers(nodeId: string): AnyInvocation[] {
return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id));
getIncomers(node: AnyInvocation): AnyInvocation[] {
return this.getEdgesTo(node).map((edge) => this.getNode(edge.source.node_id));
}
/**
@ -135,8 +97,8 @@ export class Graph {
* @returns The outgoing nodes.
* @raises `AssertionError` if the node does not exist.
*/
getOutgoers(nodeId: string): AnyInvocation[] {
return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id));
getOutgoers(node: AnyInvocation): AnyInvocation[] {
return this.getEdgesFrom(node).map((edge) => this.getNode(edge.destination.node_id));
}
//#endregion
@ -144,28 +106,26 @@ export class Graph {
/**
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
* Provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNodeId The id of the source node.
* If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNode The source node or id of the source node.
* @param fromField The field of the source node.
* @param toNodeId The id of the destination node.
* @param toNode The source node or id of the destination node.
* @param toField The field of the destination node.
* @returns The added edge.
* @raises `AssertionError` if an edge with the same source and destination already exists.
*/
addEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
fromNodeId: string,
fromField: EdgeFromField<TFrom>,
toNodeId: string,
toField: EdgeToField<TTo>
addEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNode: TFrom,
fromField: OutputFields<TFrom>,
toNode: TTo,
toField: InputFields<TTo>
): Edge {
const edge = {
source: { node_id: fromNodeId, field: fromField },
destination: { node_id: toNodeId, field: toField },
const edge: Edge = {
source: { node_id: fromNode.id, field: fromField },
destination: { node_id: toNode.id, field: toField },
};
assert(
!this._graph.edges.some((e) => isEqual(e, edge)),
Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField)
);
const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
assert(!edgeAlreadyExists, Graph.getEdgeAlreadyExistsMsg(fromNode.id, fromField, toNode.id, toField));
this._graph.edges.push(edge);
return edge;
}
@ -180,45 +140,23 @@ export class Graph {
* @returns The edge.
* @raises `AssertionError` if the edge does not exist.
*/
getEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
fromNode: string,
fromField: EdgeFromField<TFrom>,
toNode: string,
toField: EdgeToField<TTo>
getEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNode: TFrom,
fromField: OutputFields<TFrom>,
toNode: TTo,
toField: InputFields<TTo>
): Edge {
const edge = this._graph.edges.find(
(e) =>
e.source.node_id === fromNode &&
e.source.node_id === fromNode.id &&
e.source.field === fromField &&
e.destination.node_id === toNode &&
e.destination.node_id === toNode.id &&
e.destination.field === toField
);
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField));
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode.id, fromField, toNode.id, toField));
return edge;
}
/**
* Get an edge from the graph, or undefined if it doesn't exist.
* Provide the from and to node types as generics to get type hints for from and to field names.
* @param fromNodeId The id of the source node.
* @param fromField The field of the source node.
* @param toNodeId The id of the destination node.
* @param toField The field of the destination node.
* @returns The edge, or undefined if it doesn't exist.
*/
getEdgeSafe<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
fromNode: string,
fromField: EdgeFromField<TFrom>,
toNode: string,
toField: EdgeToField<TTo>
): Edge | undefined {
try {
return this.getEdge(fromNode, fromField, toNode, toField);
} catch {
return undefined;
}
}
/**
* Check if a graph has an edge.
* Provide the from and to node types as generics to get type hints for from and to field names.
@ -229,11 +167,11 @@ export class Graph {
* @returns Whether the graph has the edge.
*/
hasEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
fromNode: string,
fromField: EdgeFromField<TFrom>,
toNode: string,
toField: EdgeToField<TTo>
hasEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
fromNode: TFrom,
fromField: OutputFields<TFrom>,
toNode: TTo,
toField: InputFields<TTo>
): boolean {
try {
this.getEdge(fromNode, fromField, toNode, toField);
@ -250,8 +188,8 @@ export class Graph {
* @param fromField The field of the source node (optional).
* @returns The edges.
*/
getEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId);
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id);
if (fromField) {
edges = edges.filter((edge) => edge.source.field === fromField);
}
@ -265,8 +203,8 @@ export class Graph {
* @param toField The field of the destination node (optional).
* @returns The edges.
*/
getEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId);
getEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id);
if (toField) {
edges = edges.filter((edge) => edge.destination.field === toField);
}
@ -284,11 +222,11 @@ export class Graph {
/**
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
* Provide the to node type as a generic to get type hints for to field names.
* @param toNodeId The id of the destination node.
* @param toNode The destination node.
* @param toField The field of the destination node (optional).
*/
deleteEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): void {
for (const edge of this.getEdgesTo<TTo>(toNodeId, toField)) {
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): void {
for (const edge of this.getEdgesTo(toNode, toField)) {
this._deleteEdge(edge);
}
}
@ -299,8 +237,8 @@ export class Graph {
* @param toNodeId The id of the source node.
* @param toField The field of the source node (optional).
*/
deleteEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): void {
for (const edge of this.getEdgesFrom<TFrom>(fromNodeId, fromField)) {
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): void {
for (const edge of this.getEdgesFrom(fromNode, fromField)) {
this._deleteEdge(edge);
}
}

View File

@ -10,6 +10,7 @@ describe('MetadataUtil', () => {
describe('getNode', () => {
it('should return the metadata node if one exists', () => {
const g = new Graph();
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' });
expect(MetadataUtil.getNode(g)).toEqual(metadataNode);
});
@ -56,14 +57,16 @@ describe('MetadataUtil', () => {
it('should add an edge from from metadata to the receiving node', () => {
const n = g.addNode({ id: 'my-node', type: 'img_resize' });
MetadataUtil.add(g, { foo: 'bar' });
MetadataUtil.setMetadataReceivingNode(g, n.id);
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n.id, 'metadata')).toBe(true);
MetadataUtil.setMetadataReceivingNode(g, n);
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true);
});
it('should remove existing metadata edges', () => {
const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' });
MetadataUtil.setMetadataReceivingNode(g, n2.id);
expect(g.getIncomers(n2.id).length).toBe(1);
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n2.id, 'metadata')).toBe(true);
MetadataUtil.setMetadataReceivingNode(g, n2);
expect(g.getIncomers(n2).length).toBe(1);
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true);
});
});

View File

@ -1,34 +1,50 @@
import type { ModelIdentifierField } from 'features/nodes/types/common';
import { METADATA } from 'features/nodes/util/graph/constants';
import { isString, unset } from 'lodash-es';
import type { AnyModelConfig, Invocation } from 'services/api/types';
import type {
AnyInvocation,
AnyInvocationIncMetadata,
AnyModelConfig,
CoreMetadataInvocation,
S,
} from 'services/api/types';
import { assert } from 'tsafe';
import type { Graph } from './Graph';
const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation =>
node.type === 'core_metadata';
export class MetadataUtil {
static metadataNodeId = METADATA;
static getNode(graph: Graph): Invocation<'core_metadata'> {
return graph.getNode(this.metadataNodeId, 'core_metadata');
static getNode(g: Graph): CoreMetadataInvocation {
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
assert(isCoreMetadata(node));
return node;
}
static add(graph: Graph, metadata: Partial<Invocation<'core_metadata'>>): Invocation<'core_metadata'> {
const metadataNode = graph.getNodeSafe(this.metadataNodeId, 'core_metadata');
if (!metadataNode) {
return graph.addNode({
static add(g: Graph, metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
try {
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
assert(isCoreMetadata(node));
Object.assign(node, metadata);
return node;
} catch {
const metadataNode: CoreMetadataInvocation = {
id: this.metadataNodeId,
type: 'core_metadata',
...metadata,
});
} else {
return graph.updateNode(this.metadataNodeId, 'core_metadata', metadata);
};
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
return g.addNode(metadataNode);
}
}
static remove(graph: Graph, key: string): Invocation<'core_metadata'>;
static remove(graph: Graph, keys: string[]): Invocation<'core_metadata'>;
static remove(graph: Graph, keyOrKeys: string | string[]): Invocation<'core_metadata'> {
const metadataNode = this.getNode(graph);
static remove(g: Graph, key: string): CoreMetadataInvocation;
static remove(g: Graph, keys: string[]): CoreMetadataInvocation;
static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation {
const metadataNode = this.getNode(g);
if (isString(keyOrKeys)) {
unset(metadataNode, keyOrKeys);
} else {
@ -39,10 +55,11 @@ export class MetadataUtil {
return metadataNode;
}
static setMetadataReceivingNode(graph: Graph, nodeId: string): void {
// We need to break the rules to update metadata - `addEdge` doesn't allow `core_metadata` as a node type
graph._graph.edges = graph._graph.edges.filter((edge) => edge.source.node_id !== this.metadataNodeId);
graph.addEdge(this.metadataNodeId, 'metadata', nodeId, 'metadata');
static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void {
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
g.deleteEdgesFrom(this.getNode(g));
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
g.addEdge(this.getNode(g), 'metadata', node, 'metadata');
}
static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField {

View File

@ -131,30 +131,29 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
export type KeysOfUnion<T> = T extends T ? keyof T : never;
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache';
export type NonOutputFields = 'type';
export type AnyInvocation = Graph['nodes'][string];
export type AnyInvocationExcludeCoreMetata = Exclude<AnyInvocation, { type: 'core_metadata' }>;
export type InvocationType = AnyInvocation['type'];
export type InvocationTypeExcludeCoreMetadata = Exclude<InvocationType, 'core_metadata'>;
export type AnyInvocation = Exclude<
Graph['nodes'][string],
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
>;
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
export type InvocationType = AnyInvocation['type'];
export type InvocationOutputMap = S['InvocationOutputMap'];
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
export type InvocationExcludeCoreMetadata<T extends InvocationTypeExcludeCoreMetadata> = Extract<
AnyInvocation,
{ type: T }
>;
export type InvocationInputFields<T extends InvocationTypeExcludeCoreMetadata> = Exclude<
keyof Invocation<T>,
NonInputFields
>;
export type AnyInvocationInputField = Exclude<KeysOfUnion<AnyInvocationExcludeCoreMetata>, NonInputFields>;
export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
export type InvocationOutputFields<T extends InvocationType> = Exclude<keyof InvocationOutput<T>, NonOutputFields>;
export type AnyInvocationOutputField = Exclude<KeysOfUnion<AnyInvocationOutput>, NonOutputFields>;
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache' | 'board' | 'metadata';
export type AnyInvocationInputField = Exclude<KeysOfUnion<Required<AnyInvocation>>, NonInputFields>;
export type InputFields<T extends AnyInvocation> = Extract<keyof T, AnyInvocationInputField>;
export type NonOutputFields = 'type';
export type AnyInvocationOutputField = Exclude<KeysOfUnion<Required<AnyInvocationOutput>>, NonOutputFields>;
export type OutputFields<T extends AnyInvocation> = Extract<
keyof InvocationOutputMap[T['type']],
AnyInvocationOutputField
>;
// General nodes
export type CollectInvocation = Invocation<'collect'>;