From 9359c03c3caa67a693b06ed572c77d1c2be1b688 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 1 Jan 2024 20:37:14 +1100 Subject: [PATCH] feat(ui): use zod-less workflow builder when appropriate --- .../listeners/enqueueRequestedNodes.ts | 12 ++- .../nodes/hooks/useWorkflowWatcher.ts | 4 +- .../nodes/util/workflow/buildWorkflow.ts | 94 +++++++++++++------ 3 files changed, 76 insertions(+), 34 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index 7fd98b890c..08d59bcc98 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -1,6 +1,8 @@ import { enqueueRequested } from 'app/store/actions'; import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph'; -import { buildWorkflow } from 'features/nodes/util/workflow/buildWorkflow'; +import { + buildWorkflowRight, +} from 'features/nodes/util/workflow/buildWorkflow'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig } from 'services/api/types'; @@ -15,14 +17,16 @@ export const addEnqueueRequestedNodes = () => { const { nodes, edges } = state.nodes; const workflow = state.workflow; const graph = buildNodesGraph(state.nodes); - const builtWorkflow = buildWorkflow({ + const builtWorkflow = buildWorkflowRight({ nodes, edges, workflow, }); - // embedded workflows don't have an id - delete builtWorkflow.id; + if (builtWorkflow) { + // embedded workflows don't have an id + delete builtWorkflow.id; + } const batchConfig: BatchConfig = { batch: { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts index 564445df0f..3092c1e6ce 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts @@ -1,7 +1,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import type { WorkflowV2 } from 'features/nodes/types/workflow'; import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow'; -import { buildWorkflow } from 'features/nodes/util/workflow/buildWorkflow'; +import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import { useEffect } from 'react'; @@ -9,7 +9,7 @@ import { useEffect } from 'react'; export const $builtWorkflow = atom(null); const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => { - $builtWorkflow.set(buildWorkflow(arg)); + $builtWorkflow.set(buildWorkflowFast(arg)); }, 300); export const useWorkflowWatcher = () => { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 0f18ad4d5e..320872faf2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -2,8 +2,7 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import type { NodesState, WorkflowsState } from 'features/nodes/store/types'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowEdge, zWorkflowNode } from 'features/nodes/types/workflow'; +import { type WorkflowV2, zWorkflowV2 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; import { cloneDeep, omit } from 'lodash-es'; import { fromZodError } from 'zod-validation-error'; @@ -16,14 +15,12 @@ export type BuildWorkflowArg = { export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2; -export const buildWorkflow: BuildWorkflowFunction = ({ +export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow, -}) => { +}: BuildWorkflowArg): WorkflowV2 => { const clonedWorkflow = omit(cloneDeep(workflow), 'isTouched'); - const clonedNodes = cloneDeep(nodes); - const clonedEdges = cloneDeep(edges); const newWorkflow: WorkflowV2 = { ...clonedWorkflow, @@ -31,31 +28,72 @@ export const buildWorkflow: BuildWorkflowFunction = ({ edges: [], }; - clonedNodes - .filter((n) => isInvocationNode(n) || isNotesNode(n)) // Workflows only contain invocation and notes nodes - .forEach((node) => { - const result = zWorkflowNode.safeParse(node); - if (!result.success) { - const { message } = fromZodError(result.error, { - prefix: i18n.t('nodes.unableToParseNode'), - }); - logger('nodes').warn({ node: parseify(node) }, message); - return; - } - newWorkflow.nodes.push(result.data); - }); - - clonedEdges.forEach((edge) => { - const result = zWorkflowEdge.safeParse(edge); - if (!result.success) { - const { message } = fromZodError(result.error, { - prefix: i18n.t('nodes.unableToParseEdge'), + nodes.forEach((node) => { + if (isInvocationNode(node) && node.type) { + newWorkflow.nodes.push({ + id: node.id, + type: node.type, + data: cloneDeep(node.data), + position: { ...node.position }, + width: node.width, + height: node.height, + }); + } else if (isNotesNode(node) && node.type) { + newWorkflow.nodes.push({ + id: node.id, + type: node.type, + data: cloneDeep(node.data), + position: { ...node.position }, + width: node.width, + height: node.height, + }); + } + }); + + edges.forEach((edge) => { + if (edge.type === 'default' && edge.sourceHandle && edge.targetHandle) { + newWorkflow.edges.push({ + id: edge.id, + type: edge.type, + source: edge.source, + target: edge.target, + sourceHandle: edge.sourceHandle, + targetHandle: edge.targetHandle, + }); + } else if (edge.type === 'collapsed') { + newWorkflow.edges.push({ + id: edge.id, + type: edge.type, + source: edge.source, + target: edge.target, }); - logger('nodes').warn({ edge: parseify(edge) }, message); - return; } - newWorkflow.edges.push(result.data); }); return newWorkflow; }; + +export const buildWorkflowRight = ({ + nodes, + edges, + workflow, +}: BuildWorkflowArg): WorkflowV2 | null => { + const newWorkflowUnsafe = { + ...workflow, + nodes, + edges, + }; + + const result = zWorkflowV2.safeParse(newWorkflowUnsafe); + + if (!result.success) { + const { message } = fromZodError(result.error, { + prefix: i18n.t('nodes.unableToParseNode'), + }); + + logger('nodes').warn({ workflow: parseify(newWorkflowUnsafe) }, message); + return null; + } + + return result.data; +};