mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): support multiple fields for getEdgesTo
, getEdgesFrom
, deleteEdgesTo
, deleteEdgesFrom
This commit is contained in:
parent
2be66b1546
commit
e8d3a7c870
@ -316,7 +316,7 @@ describe('Graph', () => {
|
||||
expect(g.getEdgesFrom(n3)).toEqual([e3, e4]);
|
||||
});
|
||||
it('should return the edges that start at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesFrom(n2, 'height')).toEqual([e2]);
|
||||
expect(g.getEdgesFrom(n3, ['value'])).toEqual([e3, e4]);
|
||||
});
|
||||
});
|
||||
describe('getEdgesTo', () => {
|
||||
@ -324,7 +324,7 @@ describe('Graph', () => {
|
||||
expect(g.getEdgesTo(n3)).toEqual([e1, e2]);
|
||||
});
|
||||
it('should return the edges that end at the provided node and have the provided field', () => {
|
||||
expect(g.getEdgesTo(n3, 'b')).toEqual([e2]);
|
||||
expect(g.getEdgesTo(n3, ['b', 'a'])).toEqual([e1, e2]);
|
||||
});
|
||||
});
|
||||
describe('getIncomers', () => {
|
||||
@ -372,7 +372,7 @@ describe('Graph', () => {
|
||||
const _e1 = g.addEdge(n1, 'height', n2, 'a');
|
||||
const e2 = g.addEdge(n1, 'width', n2, 'b');
|
||||
const e3 = g.addEdge(n1, 'width', n3, 'b');
|
||||
g.deleteEdgesFrom(n1, 'height');
|
||||
g.deleteEdgesFrom(n1, ['height']);
|
||||
expect(g.getEdgesFrom(n1)).toEqual([e2, e3]);
|
||||
});
|
||||
});
|
||||
@ -410,7 +410,7 @@ describe('Graph', () => {
|
||||
const _e1 = g.addEdge(n1, 'height', n3, 'a');
|
||||
const e2 = g.addEdge(n1, 'width', n3, 'b');
|
||||
const _e3 = g.addEdge(n2, 'width', n3, 'a');
|
||||
g.deleteEdgesTo(n3, 'a');
|
||||
g.deleteEdgesTo(n3, ['a']);
|
||||
expect(g.getEdgesTo(n3)).toEqual([e2]);
|
||||
});
|
||||
});
|
||||
|
@ -231,13 +231,14 @@ export class Graph {
|
||||
* Get all edges from a node. If `fromField` is provided, only edges from that field are returned.
|
||||
* Provide the from node type as a generic to get type hints for from field names.
|
||||
* @param fromNodeId The id of the source node.
|
||||
* @param fromField The field of the source node (optional).
|
||||
* @param fromFields The field of the source node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): Edge[] {
|
||||
getEdgesFrom<T extends AnyInvocation>(fromNode: T, fromFields?: OutputFields<T>[]): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.source.node_id === fromNode.id);
|
||||
if (fromField) {
|
||||
edges = edges.filter((edge) => edge.source.field === fromField);
|
||||
if (fromFields) {
|
||||
// TODO(psyche): figure out how to satisfy TS here without casting - this is _not_ an unsafe cast
|
||||
edges = edges.filter((edge) => (fromFields as AnyInvocationOutputField[]).includes(edge.source.field));
|
||||
}
|
||||
return edges;
|
||||
}
|
||||
@ -246,13 +247,13 @@ export class Graph {
|
||||
* Get all edges to a node. If `toField` is provided, only edges to that field are returned.
|
||||
* Provide the to node type as a generic to get type hints for to field names.
|
||||
* @param toNodeId The id of the destination node.
|
||||
* @param toField The field of the destination node (optional).
|
||||
* @param toFields The field of the destination node (optional).
|
||||
* @returns The edges.
|
||||
*/
|
||||
getEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): Edge[] {
|
||||
getEdgesTo<T extends AnyInvocation>(toNode: T, toFields?: InputFields<T>[]): Edge[] {
|
||||
let edges = this._graph.edges.filter((edge) => edge.destination.node_id === toNode.id);
|
||||
if (toField) {
|
||||
edges = edges.filter((edge) => edge.destination.field === toField);
|
||||
if (toFields) {
|
||||
edges = edges.filter((edge) => (toFields as AnyInvocationInputField[]).includes(edge.destination.field));
|
||||
}
|
||||
return edges;
|
||||
}
|
||||
@ -269,10 +270,10 @@ export class Graph {
|
||||
* Delete all edges to a node. If `toField` is provided, only edges to that field are deleted.
|
||||
* Provide the to node type as a generic to get type hints for to field names.
|
||||
* @param toNode The destination node.
|
||||
* @param toField The field of the destination node (optional).
|
||||
* @param toFields The field of the destination node (optional).
|
||||
*/
|
||||
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toField?: InputFields<T>): void {
|
||||
for (const edge of this.getEdgesTo(toNode, toField)) {
|
||||
deleteEdgesTo<T extends AnyInvocation>(toNode: T, toFields?: InputFields<T>[]): void {
|
||||
for (const edge of this.getEdgesTo(toNode, toFields)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
@ -281,10 +282,10 @@ export class Graph {
|
||||
* Delete all edges from a node. If `fromField` is provided, only edges from that field are deleted.
|
||||
* Provide the from node type as a generic to get type hints for from field names.
|
||||
* @param toNodeId The id of the source node.
|
||||
* @param toField The field of the source node (optional).
|
||||
* @param fromFields The field of the source node (optional).
|
||||
*/
|
||||
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromField?: OutputFields<T>): void {
|
||||
for (const edge of this.getEdgesFrom(fromNode, fromField)) {
|
||||
deleteEdgesFrom<T extends AnyInvocation>(fromNode: T, fromFields?: OutputFields<T>[]): void {
|
||||
for (const edge of this.getEdgesFrom(fromNode, fromFields)) {
|
||||
this._deleteEdge(edge);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user