mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): refine graph building util
Simpler types and API surface.
This commit is contained in:
parent
4020bf47e2
commit
8f6078d007
@ -69,63 +69,20 @@ describe('Graph', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('updateNode', () => {
|
||||
it('should update the node with the provided id', () => {
|
||||
const g = new Graph();
|
||||
const node: Invocation<'add'> = {
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
};
|
||||
g.addNode(node);
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
expect(g.getNode('test-node', 'add').a).toBe(2);
|
||||
expect(node).toBe(updatedNode);
|
||||
});
|
||||
it('should throw an error if the node is not found', () => {
|
||||
expect(() => new Graph().updateNode('not-found', 'add', {})).toThrowError(AssertionError);
|
||||
});
|
||||
it('should throw an error if the node is found but has the wrong type', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
expect(() => g.updateNode('test-node', 'sub', {})).toThrowError(AssertionError);
|
||||
});
|
||||
it('should infer types correctly when `type` is omitted', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
assert(is<Invocation<'add'>>(updatedNode));
|
||||
});
|
||||
it('should infer types correctly when `type` is provided', () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
a: 1,
|
||||
});
|
||||
const updatedNode = g.updateNode('test-node', 'add', {
|
||||
a: 2,
|
||||
});
|
||||
assert(is<Invocation<'add'>>(updatedNode));
|
||||
});
|
||||
});
|
||||
|
||||
describe('addEdge', () => {
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
it('should add an edge to the graph with the provided values', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
g.addNode(add);
|
||||
g.addNode(sub);
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
expect(g._graph.edges.length).toBe(1);
|
||||
expect(g._graph.edges[0]).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
@ -134,19 +91,19 @@ describe('Graph', () => {
|
||||
});
|
||||
it('should throw an error if the edge already exists', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
expect(() => g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b')).toThrowError(AssertionError);
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
expect(() => g.addEdge(add, 'value', sub, 'b')).toThrowError(AssertionError);
|
||||
});
|
||||
it('should infer field names', () => {
|
||||
const g = new Graph();
|
||||
// @ts-expect-error The first field must be a valid output field of the first type arg
|
||||
g.addEdge<'add', 'sub'>('from-node', 'not-a-valid-field', 'to-node', 'a');
|
||||
g.addEdge(add, 'not-a-valid-field', add, 'a');
|
||||
// @ts-expect-error The second field must be a valid input field of the second type arg
|
||||
g.addEdge<'add', 'sub'>('from-node-2', 'value', 'to-node-2', 'not-a-valid-field');
|
||||
g.addEdge(add, 'value', sub, 'not-a-valid-field');
|
||||
// @ts-expect-error The first field must be any valid output field
|
||||
g.addEdge('from-node-3', 'not-a-valid-field', 'to-node-3', 'a');
|
||||
g.addEdge(add, 'not-a-valid-field', sub, 'a');
|
||||
// @ts-expect-error The second field must be any valid input field
|
||||
g.addEdge('from-node-4', 'clip', 'to-node-4', 'not-a-valid-field');
|
||||
g.addEdge(add, 'clip', sub, 'not-a-valid-field');
|
||||
});
|
||||
});
|
||||
|
||||
@ -161,38 +118,9 @@ describe('Graph', () => {
|
||||
const n = g.getNode('test-node');
|
||||
expect(n).toBe(node);
|
||||
});
|
||||
it('should return the node with the provided id and type', () => {
|
||||
const n = g.getNode('test-node', 'add');
|
||||
expect(n).toBe(node);
|
||||
assert(is<Invocation<'add'>>(node));
|
||||
});
|
||||
it('should throw an error if the node is not found', () => {
|
||||
expect(() => g.getNode('not-found')).toThrowError(AssertionError);
|
||||
});
|
||||
it('should throw an error if the node is found but has the wrong type', () => {
|
||||
expect(() => g.getNode('test-node', 'sub')).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getNodeSafe', () => {
|
||||
const g = new Graph();
|
||||
const node = g.addNode({
|
||||
id: 'test-node',
|
||||
type: 'add',
|
||||
});
|
||||
it('should return the node if it is found', () => {
|
||||
expect(g.getNodeSafe('test-node')).toBe(node);
|
||||
});
|
||||
it('should return the node if it is found with the provided type', () => {
|
||||
expect(g.getNodeSafe('test-node')).toBe(node);
|
||||
assert(is<Invocation<'add'>>(node));
|
||||
});
|
||||
it("should return undefined if the node isn't found", () => {
|
||||
expect(g.getNodeSafe('not-found')).toBeUndefined();
|
||||
});
|
||||
it('should return undefined if the node is found but has the wrong type', () => {
|
||||
expect(g.getNodeSafe('test-node', 'sub')).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasNode', () => {
|
||||
@ -212,40 +140,42 @@ describe('Graph', () => {
|
||||
|
||||
describe('getEdge', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
it('should return the edge with the provided values', () => {
|
||||
expect(g.getEdge('from-node', 'value', 'to-node', 'b')).toEqual({
|
||||
expect(g.getEdge(add, 'value', sub, 'b')).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
destination: { node_id: 'to-node', field: 'b' },
|
||||
});
|
||||
});
|
||||
it('should throw an error if the edge is not found', () => {
|
||||
expect(() => g.getEdge('from-node', 'value', 'to-node', 'a')).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getEdgeSafe', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
it('should return the edge if it is found', () => {
|
||||
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'b')).toEqual({
|
||||
source: { node_id: 'from-node', field: 'value' },
|
||||
destination: { node_id: 'to-node', field: 'b' },
|
||||
});
|
||||
});
|
||||
it('should return undefined if the edge is not found', () => {
|
||||
expect(g.getEdgeSafe('from-node', 'value', 'to-node', 'a')).toBeUndefined();
|
||||
expect(() => g.getEdge(add, 'value', sub, 'a')).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('hasEdge', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge<'add', 'sub'>('from-node', 'value', 'to-node', 'b');
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
it('should return true if the edge is in the graph', () => {
|
||||
expect(g.hasEdge('from-node', 'value', 'to-node', 'b')).toBe(true);
|
||||
expect(g.hasEdge(add, 'value', sub, 'b')).toBe(true);
|
||||
});
|
||||
it('should return false if the edge is not in the graph', () => {
|
||||
expect(g.hasEdge('from-node', 'value', 'to-node', 'a')).toBe(false);
|
||||
expect(g.hasEdge(add, 'value', sub, 'a')).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
@ -256,7 +186,15 @@ describe('Graph', () => {
|
||||
});
|
||||
it('should raise an error if the graph is invalid', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
expect(() => g.getGraph()).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
@ -264,7 +202,15 @@ describe('Graph', () => {
|
||||
describe('getGraphSafe', () => {
|
||||
it('should return the graph even if it is invalid', () => {
|
||||
const g = new Graph();
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
expect(g.getGraphSafe()).toBe(g._graph);
|
||||
});
|
||||
});
|
||||
@ -276,8 +222,16 @@ describe('Graph', () => {
|
||||
});
|
||||
it('should throw an error if the graph is invalid', () => {
|
||||
const g = new Graph();
|
||||
const add: Invocation<'add'> = {
|
||||
id: 'from-node',
|
||||
type: 'add',
|
||||
};
|
||||
const sub: Invocation<'sub'> = {
|
||||
id: 'to-node',
|
||||
type: 'sub',
|
||||
};
|
||||
// edge from nowhere to nowhere
|
||||
g.addEdge('from-node', 'value', 'to-node', 'b');
|
||||
g.addEdge(add, 'value', sub, 'b');
|
||||
expect(() => g.validate()).toThrowError(AssertionError);
|
||||
});
|
||||
});
|
||||
@ -304,34 +258,34 @@ describe('Graph', () => {
|
||||
id: 'n5',
|
||||
type: 'add',
|
||||
});
|
||||
const e1 = g.addEdge<'add', 'add'>(n1.id, 'value', n3.id, 'a');
|
||||
const e2 = g.addEdge<'alpha_mask_to_tensor', 'add'>(n2.id, 'height', n3.id, 'b');
|
||||
const e3 = g.addEdge<'add', 'add'>(n3.id, 'value', n4.id, 'a');
|
||||
const e4 = g.addEdge<'add', 'add'>(n3.id, 'value', n5.id, 'b');
|
||||
const e1 = g.addEdge(n1, 'value', n3, 'a');
|
||||
const e2 = g.addEdge(n2, 'height', n3, 'b');
|
||||
const e3 = g.addEdge(n3, 'value', n4, 'a');
|
||||
const e4 = g.addEdge(n3, 'value', n5, 'b');
|
||||
describe('getEdgesFrom', () => {
|
||||
it('should return the edges that start at the provided node', () => {
|
||||
expect(g.getEdgesFrom(n3.id)).toEqual([e3, e4]);
|
||||
expect(g.getEdgesFrom(n3)).toEqual([e3, e4]);
|
||||
});
|
||||
it('should return the edges that start at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesFrom(n2.id, 'height')).toEqual([e2]);
|
||||
expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
describe('getEdgesTo', () => {
|
||||
it('should return the edges that end at the provided node', () => {
|
||||
expect(g.getEdgesTo(n3.id)).toEqual([e1, e2]);
|
||||
expect(g.getEdgesTo(n3)).toEqual([e1, e2]);
|
||||
});
|
||||
it('should return the edges that end at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesTo(n3.id, 'b')).toEqual([e2]);
|
||||
expect(g.getEdgesTo(n3, 'b')).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
describe('getIncomers', () => {
|
||||
it('should return the nodes that have an edge to the provided node', () => {
|
||||
expect(g.getIncomers(n3.id)).toEqual([n1, n2]);
|
||||
expect(g.getIncomers(n3)).toEqual([n1, n2]);
|
||||
});
|
||||
});
|
||||
describe('getOutgoers', () => {
|
||||
it('should return the nodes that the provided node has an edge to', () => {
|
||||
expect(g.getOutgoers(n3.id)).toEqual([n4, n5]);
|
||||
expect(g.getOutgoers(n3)).toEqual([n4, n5]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
@ -3,31 +3,26 @@ import type {
|
||||
AnyInvocation,
|
||||
AnyInvocationInputField,
|
||||
AnyInvocationOutputField,
|
||||
InputFields,
|
||||
Invocation,
|
||||
InvocationInputFields,
|
||||
InvocationOutputFields,
|
||||
InvocationType,
|
||||
S,
|
||||
OutputFields,
|
||||
} from 'services/api/types';
|
||||
import type { O } from 'ts-toolbelt';
|
||||
import { assert } from 'tsafe';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type GraphType = O.NonNullable<O.Required<S['Graph']>>;
|
||||
type Edge = GraphType['edges'][number];
|
||||
type Never = Record<string, never>;
|
||||
type Edge = {
|
||||
source: {
|
||||
node_id: string;
|
||||
field: AnyInvocationOutputField;
|
||||
};
|
||||
destination: {
|
||||
node_id: string;
|
||||
field: AnyInvocationInputField;
|
||||
};
|
||||
};
|
||||
|
||||
// The `core_metadata` node has very lax types, it accepts arbitrary field names. It must be excluded from edge utils
|
||||
// to preview their types from being widened from a union of valid field names to `string | number | symbol`.
|
||||
type EdgeNodeType = Exclude<InvocationType, 'core_metadata'>;
|
||||
|
||||
type EdgeFromField<TFrom extends EdgeNodeType | Never = Never> = TFrom extends EdgeNodeType
|
||||
? InvocationOutputFields<TFrom>
|
||||
: AnyInvocationOutputField;
|
||||
|
||||
type EdgeToField<TTo extends EdgeNodeType | Never = Never> = TTo extends EdgeNodeType
|
||||
? InvocationInputFields<TTo>
|
||||
: AnyInvocationInputField;
|
||||
type GraphType = { id: string; nodes: Record<string, AnyInvocation>; edges: Edge[] };
|
||||
|
||||
export class Graph {
|
||||
_graph: GraphType;
|
||||
@ -64,45 +59,12 @@ export class Graph {
|
||||
/**
|
||||
* Gets a node from the graph.
|
||||
* @param id The id of the node to get.
|
||||
* @param type The type of the node to get. If provided, the retrieved node is guaranteed to be of this type.
|
||||
* @returns The node.
|
||||
* @raises `AssertionError` if the node does not exist or if a `type` is provided but the node is not of the expected type.
|
||||
*/
|
||||
getNode<T extends InvocationType>(id: string, type?: T): Invocation<T> {
|
||||
getNode(id: string): AnyInvocation {
|
||||
const node = this._graph.nodes[id];
|
||||
assert(node !== undefined, Graph.getNodeNotFoundMsg(id));
|
||||
if (type) {
|
||||
assert(node.type === type, Graph.getNodeNotOfTypeMsg(node, type));
|
||||
}
|
||||
// We just asserted that the node type is correct, this is OK to cast
|
||||
return node as Invocation<T>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Gets a node from the graph without raising an error if the node does not exist or is not of the expected type.
|
||||
* @param id The id of the node to get.
|
||||
* @param type The type of the node to get. If provided, node is guaranteed to be of this type.
|
||||
* @returns The node, if it exists and is of the correct type. Otherwise, `undefined`.
|
||||
*/
|
||||
getNodeSafe<T extends InvocationType>(id: string, type?: T): Invocation<T> | undefined {
|
||||
try {
|
||||
return this.getNode(id, type);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Update a node in the graph. Properties are shallow-copied from `updates` to the node.
|
||||
* @param id The id of the node to update.
|
||||
* @param type The type of the node to update. If provided, node is guaranteed to be of this type.
|
||||
* @param updates The fields to update on the node.
|
||||
* @returns The updated node.
|
||||
* @raises `AssertionError` if the node does not exist or its type doesn't match.
|
||||
*/
|
||||
updateNode<T extends InvocationType>(id: string, type: T, updates: Partial<Invocation<T>>): Invocation<T> {
|
||||
const node = this.getNode(id, type);
|
||||
Object.assign(node, updates);
|
||||
return node;
|
||||
}
|
||||
|
||||
@ -125,8 +87,8 @@ export class Graph {
|
||||
* @returns The incoming nodes.
|
||||
* @raises `AssertionError` if the node does not exist.
|
||||
*/
|
||||
getIncomers(nodeId: string): AnyInvocation[] {
|
||||
return this.getEdgesTo(nodeId).map((edge) => this.getNode(edge.source.node_id));
|
||||
getIncomers(node: AnyInvocation): AnyInvocation[] {
|
||||
return this.getEdgesTo(node).map((edge) => this.getNode(edge.source.node_id));
|
||||
}
|
||||
|
||||
/**
|
||||
@ -135,8 +97,8 @@ export class Graph {
|
||||
* @returns The outgoing nodes.
|
||||
* @raises `AssertionError` if the node does not exist.
|
||||
*/
|
||||
getOutgoers(nodeId: string): AnyInvocation[] {
|
||||
return this.getEdgesFrom(nodeId).map((edge) => this.getNode(edge.destination.node_id));
|
||||
getOutgoers(node: AnyInvocation): AnyInvocation[] {
|
||||
return this.getEdgesFrom(node).map((edge) => this.getNode(edge.destination.node_id));
|
||||
}
|
||||
//#endregion
|
||||
|
||||
@ -144,28 +106,26 @@ export class Graph {
|
||||
|
||||
/**
|
||||
* Add an edge to the graph. If an edge with the same source and destination already exists, an `AssertionError` is raised.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* If providing node ids, provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNode The source node or id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toNode The source node or id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The added edge.
|
||||
* @raises `AssertionError` if an edge with the same source and destination already exists.
|
||||
*/
|
||||
addEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNodeId: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNodeId: string,
|
||||
toField: EdgeToField<TTo>
|
||||
addEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
|
||||
fromNode: TFrom,
|
||||
fromField: OutputFields<TFrom>,
|
||||
toNode: TTo,
|
||||
toField: InputFields<TTo>
|
||||
): Edge {
|
||||
const edge = {
|
||||
source: { node_id: fromNodeId, field: fromField },
|
||||
destination: { node_id: toNodeId, field: toField },
|
||||
const edge: Edge = {
|
||||
source: { node_id: fromNode.id, field: fromField },
|
||||
destination: { node_id: toNode.id, field: toField },
|
||||
};
|
||||
assert(
|
||||
!this._graph.edges.some((e) => isEqual(e, edge)),
|
||||
Graph.getEdgeAlreadyExistsMsg(fromNodeId, fromField, toNodeId, toField)
|
||||
);
|
||||
const edgeAlreadyExists = this._graph.edges.some((e) => isEqual(e, edge));
|
||||
assert(!edgeAlreadyExists, Graph.getEdgeAlreadyExistsMsg(fromNode.id, fromField, toNode.id, toField));
|
||||
this._graph.edges.push(edge);
|
||||
return edge;
|
||||
}
|
||||
@ -180,45 +140,23 @@ export class Graph {
|
||||
* @returns The edge.
|
||||
* @raises `AssertionError` if the edge does not exist.
|
||||
*/
|
||||
getEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
getEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
|
||||
fromNode: TFrom,
|
||||
fromField: OutputFields<TFrom>,
|
||||
toNode: TTo,
|
||||
toField: InputFields<TTo>
|
||||
): Edge {
|
||||
const edge = this._graph.edges.find(
|
||||
(e) =>
|
||||
e.source.node_id === fromNode &&
|
||||
e.source.node_id === fromNode.id &&
|
||||
e.source.field === fromField &&
|
||||
e.destination.node_id === toNode &&
|
||||
e.destination.node_id === toNode.id &&
|
||||
e.destination.field === toField
|
||||
);
|
||||
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode, fromField, toNode, toField));
|
||||
assert(edge !== undefined, Graph.getEdgeNotFoundMsg(fromNode.id, fromField, toNode.id, toField));
|
||||
return edge;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an edge from the graph, or undefined if it doesn't exist.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node.
|
||||
* @returns The edge, or undefined if it doesn't exist.
|
||||
*/
|
||||
getEdgeSafe<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
): Edge | undefined {
|
||||
try {
|
||||
return this.getEdge(fromNode, fromField, toNode, toField);
|
||||
} catch {
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a graph has an edge.
|
||||
* Provide the from and to node types as generics to get type hints for from and to field names.
|
||||
@ -229,11 +167,11 @@ export class Graph {
|
||||
* @returns Whether the graph has the edge.
|
||||
*/
|
||||
|
||||
hasEdge<TFrom extends EdgeNodeType, TTo extends EdgeNodeType>(
|
||||
fromNode: string,
|
||||
fromField: EdgeFromField<TFrom>,
|
||||
toNode: string,
|
||||
toField: EdgeToField<TTo>
|
||||
hasEdge<TFrom extends AnyInvocation, TTo extends AnyInvocation>(
|
||||
fromNode: TFrom,
|
||||
fromField: OutputFields<TFrom>,
|
||||
toNode: TTo,
|
||||
toField: InputFields<TTo>
|
||||
): boolean {
|
||||
try {
|
||||
this.getEdge(fromNode, fromField, toNode, toField);
|
||||
@ -250,8 +188,8 @@ export class Graph {
|
||||
* @param fromField The field of the source node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNodeId);
|
||||
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id);
|
||||
if (fromField) {
|
||||
edges = edges.filter((edge) => edge.source.field === fromField);
|
||||
}
|
||||
@ -265,8 +203,8 @@ export class Graph {
|
||||
* @param toField The field of the destination node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNodeId);
|
||||
getEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id);
|
||||
if (toField) {
|
||||
edges = edges.filter((edge) => edge.destination.field === toField);
|
||||
}
|
||||
@ -284,11 +222,11 @@ export class Graph {
|
||||
/**
|
||||
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
|
||||
* Provide the to node type as a generic to get type hints for to field names.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toNode The destination node.
|
||||
* @param toField The field of the destination node (optional).
|
||||
*/
|
||||
deleteEdgesTo<TTo extends EdgeNodeType>(toNodeId: string, toField?: EdgeToField<TTo>): void {
|
||||
for (const edge of this.getEdgesTo<TTo>(toNodeId, toField)) {
|
||||
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): void {
|
||||
for (const edge of this.getEdgesTo(toNode, toField)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
@ -299,8 +237,8 @@ export class Graph {
|
||||
* @param toNodeId The id of the source node.
|
||||
* @param toField The field of the source node (optional).
|
||||
*/
|
||||
deleteEdgesFrom<TFrom extends EdgeNodeType>(fromNodeId: string, fromField?: EdgeFromField<TFrom>): void {
|
||||
for (const edge of this.getEdgesFrom<TFrom>(fromNodeId, fromField)) {
|
||||
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): void {
|
||||
for (const edge of this.getEdgesFrom(fromNode, fromField)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ describe('MetadataUtil', () => {
|
||||
describe('getNode', () => {
|
||||
it('should return the metadata node if one exists', () => {
|
||||
const g = new Graph();
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
const metadataNode = g.addNode({ id: MetadataUtil.metadataNodeId, type: 'core_metadata' });
|
||||
expect(MetadataUtil.getNode(g)).toEqual(metadataNode);
|
||||
});
|
||||
@ -56,14 +57,16 @@ describe('MetadataUtil', () => {
|
||||
it('should add an edge from from metadata to the receiving node', () => {
|
||||
const n = g.addNode({ id: 'my-node', type: 'img_resize' });
|
||||
MetadataUtil.add(g, { foo: 'bar' });
|
||||
MetadataUtil.setMetadataReceivingNode(g, n.id);
|
||||
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n.id, 'metadata')).toBe(true);
|
||||
MetadataUtil.setMetadataReceivingNode(g, n);
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n, 'metadata')).toBe(true);
|
||||
});
|
||||
it('should remove existing metadata edges', () => {
|
||||
const n2 = g.addNode({ id: 'my-other-node', type: 'img_resize' });
|
||||
MetadataUtil.setMetadataReceivingNode(g, n2.id);
|
||||
expect(g.getIncomers(n2.id).length).toBe(1);
|
||||
expect(g.hasEdge(MetadataUtil.metadataNodeId, 'metadata', n2.id, 'metadata')).toBe(true);
|
||||
MetadataUtil.setMetadataReceivingNode(g, n2);
|
||||
expect(g.getIncomers(n2).length).toBe(1);
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
expect(g.hasEdge(MetadataUtil.getNode(g), 'metadata', n2, 'metadata')).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -1,34 +1,50 @@
|
||||
import type { ModelIdentifierField } from 'features/nodes/types/common';
|
||||
import { METADATA } from 'features/nodes/util/graph/constants';
|
||||
import { isString, unset } from 'lodash-es';
|
||||
import type { AnyModelConfig, Invocation } from 'services/api/types';
|
||||
import type {
|
||||
AnyInvocation,
|
||||
AnyInvocationIncMetadata,
|
||||
AnyModelConfig,
|
||||
CoreMetadataInvocation,
|
||||
S,
|
||||
} from 'services/api/types';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
import type { Graph } from './Graph';
|
||||
|
||||
const isCoreMetadata = (node: S['Graph']['nodes'][string]): node is CoreMetadataInvocation =>
|
||||
node.type === 'core_metadata';
|
||||
|
||||
export class MetadataUtil {
|
||||
static metadataNodeId = METADATA;
|
||||
|
||||
static getNode(graph: Graph): Invocation<'core_metadata'> {
|
||||
return graph.getNode(this.metadataNodeId, 'core_metadata');
|
||||
static getNode(g: Graph): CoreMetadataInvocation {
|
||||
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
|
||||
assert(isCoreMetadata(node));
|
||||
return node;
|
||||
}
|
||||
|
||||
static add(graph: Graph, metadata: Partial<Invocation<'core_metadata'>>): Invocation<'core_metadata'> {
|
||||
const metadataNode = graph.getNodeSafe(this.metadataNodeId, 'core_metadata');
|
||||
if (!metadataNode) {
|
||||
return graph.addNode({
|
||||
static add(g: Graph, metadata: Partial<CoreMetadataInvocation>): CoreMetadataInvocation {
|
||||
try {
|
||||
const node = g.getNode(this.metadataNodeId) as AnyInvocationIncMetadata;
|
||||
assert(isCoreMetadata(node));
|
||||
Object.assign(node, metadata);
|
||||
return node;
|
||||
} catch {
|
||||
const metadataNode: CoreMetadataInvocation = {
|
||||
id: this.metadataNodeId,
|
||||
type: 'core_metadata',
|
||||
...metadata,
|
||||
});
|
||||
} else {
|
||||
return graph.updateNode(this.metadataNodeId, 'core_metadata', metadata);
|
||||
};
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
return g.addNode(metadataNode);
|
||||
}
|
||||
}
|
||||
|
||||
static remove(graph: Graph, key: string): Invocation<'core_metadata'>;
|
||||
static remove(graph: Graph, keys: string[]): Invocation<'core_metadata'>;
|
||||
static remove(graph: Graph, keyOrKeys: string | string[]): Invocation<'core_metadata'> {
|
||||
const metadataNode = this.getNode(graph);
|
||||
static remove(g: Graph, key: string): CoreMetadataInvocation;
|
||||
static remove(g: Graph, keys: string[]): CoreMetadataInvocation;
|
||||
static remove(g: Graph, keyOrKeys: string | string[]): CoreMetadataInvocation {
|
||||
const metadataNode = this.getNode(g);
|
||||
if (isString(keyOrKeys)) {
|
||||
unset(metadataNode, keyOrKeys);
|
||||
} else {
|
||||
@ -39,10 +55,11 @@ export class MetadataUtil {
|
||||
return metadataNode;
|
||||
}
|
||||
|
||||
static setMetadataReceivingNode(graph: Graph, nodeId: string): void {
|
||||
// We need to break the rules to update metadata - `addEdge` doesn't allow `core_metadata` as a node type
|
||||
graph._graph.edges = graph._graph.edges.filter((edge) => edge.source.node_id !== this.metadataNodeId);
|
||||
graph.addEdge(this.metadataNodeId, 'metadata', nodeId, 'metadata');
|
||||
static setMetadataReceivingNode(g: Graph, node: AnyInvocation): void {
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
g.deleteEdgesFrom(this.getNode(g));
|
||||
// @ts-expect-error `Graph` excludes `core_metadata` nodes due to its excessively wide typing
|
||||
g.addEdge(this.getNode(g), 'metadata', node, 'metadata');
|
||||
}
|
||||
|
||||
static getModelMetadataField({ key, hash, name, base, type }: AnyModelConfig): ModelIdentifierField {
|
||||
|
@ -131,30 +131,29 @@ export type WorkflowRecordListItemDTO = S['WorkflowRecordListItemDTO'];
|
||||
|
||||
export type KeysOfUnion<T> = T extends T ? keyof T : never;
|
||||
|
||||
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache';
|
||||
export type NonOutputFields = 'type';
|
||||
export type AnyInvocation = Graph['nodes'][string];
|
||||
export type AnyInvocationExcludeCoreMetata = Exclude<AnyInvocation, { type: 'core_metadata' }>;
|
||||
export type InvocationType = AnyInvocation['type'];
|
||||
export type InvocationTypeExcludeCoreMetadata = Exclude<InvocationType, 'core_metadata'>;
|
||||
export type AnyInvocation = Exclude<
|
||||
Graph['nodes'][string],
|
||||
S['CoreMetadataInvocation'] | S['MetadataInvocation'] | S['MetadataItemInvocation'] | S['MergeMetadataInvocation']
|
||||
>;
|
||||
export type AnyInvocationIncMetadata = S['Graph']['nodes'][string];
|
||||
|
||||
export type InvocationType = AnyInvocation['type'];
|
||||
export type InvocationOutputMap = S['InvocationOutputMap'];
|
||||
export type AnyInvocationOutput = InvocationOutputMap[InvocationType];
|
||||
|
||||
export type Invocation<T extends InvocationType> = Extract<AnyInvocation, { type: T }>;
|
||||
export type InvocationExcludeCoreMetadata<T extends InvocationTypeExcludeCoreMetadata> = Extract<
|
||||
AnyInvocation,
|
||||
{ type: T }
|
||||
>;
|
||||
export type InvocationInputFields<T extends InvocationTypeExcludeCoreMetadata> = Exclude<
|
||||
keyof Invocation<T>,
|
||||
NonInputFields
|
||||
>;
|
||||
export type AnyInvocationInputField = Exclude<KeysOfUnion<AnyInvocationExcludeCoreMetata>, NonInputFields>;
|
||||
|
||||
export type InvocationOutput<T extends InvocationType> = InvocationOutputMap[T];
|
||||
export type InvocationOutputFields<T extends InvocationType> = Exclude<keyof InvocationOutput<T>, NonOutputFields>;
|
||||
export type AnyInvocationOutputField = Exclude<KeysOfUnion<AnyInvocationOutput>, NonOutputFields>;
|
||||
|
||||
export type NonInputFields = 'id' | 'type' | 'is_intermediate' | 'use_cache' | 'board' | 'metadata';
|
||||
export type AnyInvocationInputField = Exclude<KeysOfUnion<Required<AnyInvocation>>, NonInputFields>;
|
||||
export type InputFields<T extends AnyInvocation> = Extract<keyof T, AnyInvocationInputField>;
|
||||
|
||||
export type NonOutputFields = 'type';
|
||||
export type AnyInvocationOutputField = Exclude<KeysOfUnion<Required<AnyInvocationOutput>>, NonOutputFields>;
|
||||
export type OutputFields<T extends AnyInvocation> = Extract<
|
||||
keyof InvocationOutputMap[T['type']],
|
||||
AnyInvocationOutputField
|
||||
>;
|
||||
|
||||
// General nodes
|
||||
export type CollectInvocation = Invocation<'collect'>;
|
||||
|
Loading…
Reference in New Issue
Block a user