From ce2ad5903c8878d80f400028d8e3f69690eced98 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 09:09:42 +1000 Subject: [PATCH] feat(ui): extract logic for finding candidate fields to own function --- .../store/util/getFirstValidConnection.ts | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 00899c065d..1449a3298a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -116,3 +116,75 @@ export const getFirstValidConnection = ( return null; }; + +export const getTargetCandidateFields = ( + source: string, + sourceHandle: string, + target: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldInputTemplate[] => { + const sourceNode = nodes.find((n) => n.id === source); + const targetNode = nodes.find((n) => n.id === target); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } + + const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { + const c = { source, sourceHandle, target, targetHandle: field.name }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return targetCandidateFields; +}; + +export const getSourceCandidateFields = ( + target: string, + targetHandle: string, + source: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldOutputTemplate[] => { + const targetNode = nodes.find((n) => n.id === target); + const sourceNode = nodes.find((n) => n.id === source); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + + const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => { + const c = { source, sourceHandle: field.name, target, targetHandle }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return sourceCandidateFields; +};