feat(ui): support multiple fields for getEdgesTo, getEdgesFrom, deleteEdgesTo, deleteEdgesFrom

This commit is contained in:
psychedelicious 2024-05-13 16:07:34 +10:00
parent 2be66b1546
commit e8d3a7c870
2 changed files with 19 additions and 18 deletions
invokeai/frontend/web/src/features/nodes/util/graph

View File

@ -316,7 +316,7 @@ describe('Graph', () => {
expect(g.getEdgesFrom(n3)).toEqual([e3, e4]);
});
it('should return the edges that start at the provided node and have the provided field', () => {
expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]);
expect(g.getEdgesFrom(n3, ['value'])).toEqual([e3, e4]);
});
});
describe('getEdgesTo', () => {
@ -324,7 +324,7 @@ describe('Graph', () => {
expect(g.getEdgesTo(n3)).toEqual([e1, e2]);
});
it('should return the edges that end at the provided node and have the provided field', () => {
expect(g.getEdgesTo(n3, 'b')).toEqual([e2]);
expect(g.getEdgesTo(n3, ['b', 'a'])).toEqual([e1, e2]);
});
});
describe('getIncomers', () => {
@ -372,7 +372,7 @@ describe('Graph', () => {
const _e1 = g.addEdge(n1, 'height', n2, 'a');
const e2 = g.addEdge(n1, 'width', n2, 'b');
const e3 = g.addEdge(n1, 'width', n3, 'b');
g.deleteEdgesFrom(n1, 'height');
g.deleteEdgesFrom(n1, ['height']);
expect(g.getEdgesFrom(n1)).toEqual([e2, e3]);
});
});
@ -410,7 +410,7 @@ describe('Graph', () => {
const _e1 = g.addEdge(n1, 'height', n3, 'a');
const e2 = g.addEdge(n1, 'width', n3, 'b');
const _e3 = g.addEdge(n2, 'width', n3, 'a');
g.deleteEdgesTo(n3, 'a');
g.deleteEdgesTo(n3, ['a']);
expect(g.getEdgesTo(n3)).toEqual([e2]);
});
});

View File

@ -231,13 +231,14 @@ export class Graph {
* Get all edges from a node. If `fromField` is provided, only edges from that field are returned.
* Provide the from node type as a generic to get type hints for from field names.
* @param fromNodeId The id of the source node.
* @param fromField The field of the source node (optional).
* @param fromFields The field of the source node (optional).
* @returns The edges.
*/
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): Edge[] {
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromFields?: OutputFields<T>[]): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id);
if (fromField) {
edges = edges.filter((edge) => edge.source.field === fromField);
if (fromFields) {
// TODO(psyche): figure out how to satisfy TS here without casting - this is _not_ an unsafe cast
edges = edges.filter((edge) => (fromFields as AnyInvocationOutputField[]).includes(edge.source.field));
}
return edges;
}
@ -246,13 +247,13 @@ export class Graph {
* Get all edges to a node. If `toField` is provided, only edges to that field are returned.
* Provide the to node type as a generic to get type hints for to field names.
* @param toNodeId The id of the destination node.
* @param toField The field of the destination node (optional).
* @param toFields The field of the destination node (optional).
* @returns The edges.
*/
getEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): Edge[] {
getEdgesTo<T extends AnyInvocation>(toNode: T, toFields?: InputFields<T>[]): Edge[] {
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id);
if (toField) {
edges = edges.filter((edge) => edge.destination.field === toField);
if (toFields) {
edges = edges.filter((edge) => (toFields as AnyInvocationInputField[]).includes(edge.destination.field));
}
return edges;
}
@ -269,10 +270,10 @@ export class Graph {
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
* Provide the to node type as a generic to get type hints for to field names.
* @param toNode The destination node.
* @param toField The field of the destination node (optional).
* @param toFields The field of the destination node (optional).
*/
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): void {
for (const edge of this.getEdgesTo(toNode, toField)) {
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toFields?: InputFields<T>[]): void {
for (const edge of this.getEdgesTo(toNode, toFields)) {
this._deleteEdge(edge);
}
}
@ -281,10 +282,10 @@ export class Graph {
* Delete all edges from a node. If `fromField` is provided, only edges from that field are deleted.
* Provide the from node type as a generic to get type hints for from field names.
* @param toNodeId The id of the source node.
* @param toField The field of the source node (optional).
* @param fromFields The field of the source node (optional).
*/
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): void {
for (const edge of this.getEdgesFrom(fromNode, fromField)) {
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromFields?: OutputFields<T>[]): void {
for (const edge of this.getEdgesFrom(fromNode, fromFields)) {
this._deleteEdge(edge);
}
}