From ad8778df6c004246b2d7a06cc42a37f29cc45a79 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 13:21:01 +1000 Subject: [PATCH] feat(ui): extract node execution state from nodesSlice This state is ephemeral and not undoable. --- .../socketio/socketGeneratorProgress.ts | 10 ++ .../socketio/socketInvocationComplete.ts | 15 ++- .../socketio/socketInvocationError.ts | 12 +++ .../socketio/socketInvocationStarted.ts | 9 ++ .../socketio/socketQueueItemStatusChanged.ts | 19 ++++ .../features/nodes/components/flow/Flow.tsx | 2 + .../InvocationNodeStatusIndicator.tsx | 13 +-- .../flow/nodes/common/NodeWrapper.tsx | 18 +--- .../inspector/InspectorOutputsTab.tsx | 16 +-- .../features/nodes/hooks/useExecutionState.ts | 56 +++++++++++ .../src/features/nodes/store/nodesSlice.ts | 97 +------------------ .../web/src/features/nodes/store/types.ts | 2 +- 12 files changed, 141 insertions(+), 128 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useExecutionState.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index bb113a09ee..2dd598396a 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -1,5 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { deepClone } from 'common/util/deepClone'; +import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; const log = logger('socketio'); @@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis actionCreator: socketGeneratorProgress, effect: (action) => { log.trace(action.payload, `Generator progress`); + const { source_node_id, step, total_steps, progress_image } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + nes.progress = (step + 1) / total_steps; + nes.progressImage = progress_image ?? null; + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index fb3a4a41c9..06dc08d846 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; import { addImageToStagingArea } from 'features/canvas/store/canvasSlice'; import { @@ -9,7 +10,9 @@ import { isImageViewerOpenChanged, } from 'features/gallery/store/gallerySlice'; import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { isImageOutput } from 'features/nodes/types/common'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; @@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi const { data } = action.payload; log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`); - const { result, node, queue_batch_id } = data; + const { result, node, queue_batch_id, source_node_id } = data; // This complete event has an associated image output if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) { const { image_name } = result.image; @@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi } } } + + const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + if (nes) { + nes.status = zNodeStatus.enum.COMPLETED; + if (nes.progress !== null) { + nes.progress = 1; + } + nes.outputs.push(result); + upsertExecutionState(nes.nodeId, nes); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts index fb898b4c7a..ce26c4dd7d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationError.ts @@ -1,5 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { deepClone } from 'common/util/deepClone'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketInvocationError } from 'services/events/actions'; const log = logger('socketio'); @@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe actionCreator: socketInvocationError, effect: (action) => { log.error(action.payload, `Invocation error (${action.payload.data.node.type})`); + const { source_node_id } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + if (nes) { + nes.status = zNodeStatus.enum.FAILED; + nes.error = action.payload.data.error; + nes.progress = null; + nes.progressImage = null; + upsertExecutionState(nes.nodeId, nes); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts index baf476a66b..9d6e0ac14d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationStarted.ts @@ -1,5 +1,8 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { deepClone } from 'common/util/deepClone'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketInvocationStarted } from 'services/events/actions'; const log = logger('socketio'); @@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis actionCreator: socketInvocationStarted, effect: (action) => { log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`); + const { source_node_id } = action.payload.data; + const nes = deepClone($nodeExecutionStates.get()[source_node_id]); + if (nes) { + nes.status = zNodeStatus.enum.IN_PROGRESS; + upsertExecutionState(nes.nodeId, nes); + } }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts index 84073bb427..2adc529766 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketQueueItemStatusChanged.ts @@ -1,5 +1,9 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { deepClone } from 'common/util/deepClone'; +import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import { forEach } from 'lodash-es'; import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue'; import { socketQueueItemStatusChanged } from 'services/events/actions'; @@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening: dispatch( queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus']) ); + + if (['in_progress'].includes(action.payload.data.queue_item.status)) { + forEach($nodeExecutionStates.get(), (nes) => { + if (!nes) { + return; + } + const clone = deepClone(nes); + clone.status = zNodeStatus.enum.PENDING; + clone.error = null; + clone.progress = null; + clone.progressImage = null; + clone.outputs = []; + $nodeExecutionStates.setKey(clone.nodeId, clone); + }); + } }, }); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index f1fcf24af2..8b33323ddd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useConnection } from 'features/nodes/hooks/useConnection'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; +import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState'; import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { @@ -81,6 +82,7 @@ export const Flow = memo(() => { const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); useWorkflowWatcher(); + useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); const flowStyles = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx index 3138cb32fe..b58f6fe8ba 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeStatusIndicator.tsx @@ -1,12 +1,10 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { useExecutionState } from 'features/nodes/hooks/useExecutionState'; import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; import type { NodeExecutionState } from 'features/nodes/types/invocation'; import { zNodeStatus } from 'features/nodes/types/invocation'; -import { memo, useMemo } from 'react'; +import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi'; @@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = { }; const InvocationNodeStatusIndicator = ({ nodeId }: Props) => { - const selectNodeExecutionState = useMemo( - () => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]), - [nodeId] - ); - - const nodeExecutionState = useAppSelector(selectNodeExecutionState); + const nodeExecutionState = useExecutionState(nodeId); if (!nodeExecutionState) { return null; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index 51649f4f82..57426982ef 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -1,14 +1,14 @@ import type { ChakraProps } from '@invoke-ai/ui-library'; import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; -import { createSelector } from '@reduxjs/toolkit'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; +import { useExecutionState } from 'features/nodes/hooks/useExecutionState'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; -import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants'; import { zNodeStatus } from 'features/nodes/types/invocation'; import type { MouseEvent, PropsWithChildren } from 'react'; -import { memo, useCallback, useMemo } from 'react'; +import { memo, useCallback } from 'react'; type NodeWrapperProps = PropsWithChildren & { nodeId: string; @@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => { const { nodeId, width, children, selected } = props; const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); - const selectIsInProgress = useMemo( - () => - createSelector( - selectNodesSlice, - (nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS - ), - [nodeId] - ); - - const isInProgress = useAppSelector(selectIsInProgress); + const executionState = useExecutionState(nodeId); + const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS; const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [ 'nodeInProgress', diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index 17a1dd33f1..d4150243b9 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,6 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; +import { useExecutionState } from 'features/nodes/hooks/useExecutionState'; import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectLastSelectedNode } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; @@ -23,27 +24,26 @@ const InspectorOutputsTab = () => { const lastSelectedNode = selectLastSelectedNode(nodes); const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined; - const nes = nodes.nodeExecutionStates[lastSelectedNode?.id ?? '__UNKNOWN_NODE__']; - - if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) { + if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; } return { - outputs: nes.outputs, + nodeId: lastSelectedNode.id, outputType: lastSelectedNodeTemplate.outputType, }; }), [templates] ); const data = useAppSelector(selector); + const nes = useExecutionState(data?.nodeId); const { t } = useTranslation(); - if (!data) { + if (!data || !nes) { return ; } - if (data.outputs.length === 0) { + if (nes.outputs.length === 0) { return ; } @@ -52,11 +52,11 @@ const InspectorOutputsTab = () => { {data.outputType === 'image_output' ? ( - data.outputs.map((result, i) => ( + nes.outputs.map((result, i) => ( )) ) : ( - + )} diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useExecutionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useExecutionState.ts new file mode 100644 index 0000000000..0e5dc1ac43 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useExecutionState.ts @@ -0,0 +1,56 @@ +import { useStore } from '@nanostores/react'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppSelector } from 'app/store/storeHooks'; +import { deepClone } from 'common/util/deepClone'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import type { NodeExecutionStates } from 'features/nodes/store/types'; +import type { NodeExecutionState } from 'features/nodes/types/invocation'; +import { zNodeStatus } from 'features/nodes/types/invocation'; +import { map } from 'nanostores'; +import { useEffect, useMemo } from 'react'; + +export const $nodeExecutionStates = map({}); + +const initialNodeExecutionState: Omit = { + status: zNodeStatus.enum.PENDING, + error: null, + progress: null, + progressImage: null, + outputs: [], +}; + +export const useExecutionState = (nodeId?: string) => { + const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined); + const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]); + return executionState; +}; + +const removeNodeExecutionState = (nodeId: string) => { + $nodeExecutionStates.setKey(nodeId, undefined); +}; + +export const upsertExecutionState = (nodeId: string, updates?: Partial) => { + const state = $nodeExecutionStates.get()[nodeId]; + if (!state) { + $nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates }); + } else { + $nodeExecutionStates.setKey(nodeId, { ...state, ...updates }); + } +}; + +const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id)); + +export const useSyncExecutionState = () => { + const nodeIds = useAppSelector(selectNodeIds); + useEffect(() => { + const nodeExecutionStates = $nodeExecutionStates.get(); + const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]); + const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id)); + for (const id of nodeIdsToAdd) { + upsertExecutionState(id); + } + for (const id of nodeIdsToRemove) { + removeNodeExecutionState(id); + } + }, [nodeIds]); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 2218530c31..644287dd29 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,7 +1,6 @@ import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { @@ -43,38 +42,21 @@ import { zT2IAdapterModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; -import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/nodes/types/invocation'; -import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; -import { forEach } from 'lodash-es'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; +import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; -import { - socketGeneratorProgress, - socketInvocationComplete, - socketInvocationError, - socketInvocationStarted, - socketQueueItemStatusChanged, -} from 'services/events/actions'; import type { z } from 'zod'; import type { NodesState, PendingConnection, Templates } from './types'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; -const initialNodeExecutionState: Omit = { - status: zNodeStatus.enum.PENDING, - error: null, - progress: null, - progressImage: null, - outputs: [], -}; - const initialNodesState: NodesState = { _version: 1, nodes: [], edges: [], - nodeExecutionStates: {}, }; type FieldValueAction = PayloadAction<{ @@ -137,15 +119,6 @@ export const nodesSlice = createSlice({ ); state.nodes.push(node); - - if (!isInvocationNode(node)) { - return; - } - - state.nodeExecutionStates[node.id] = { - nodeId: node.id, - ...initialNodeExecutionState, - }; }, edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); @@ -316,7 +289,6 @@ export const nodesSlice = createSlice({ if (!isInvocationNode(node)) { return; } - delete state.nodeExecutionStates[node.id]; }); }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { @@ -459,14 +431,6 @@ export const nodesSlice = createSlice({ state.nodes = applyNodeChanges(nodeChanges, state.nodes); state.edges = applyEdgeChanges(edgeChanges, state.edges); - - // Add node execution states for new nodes - nodes.forEach((node) => { - state.nodeExecutionStates[node.id] = { - nodeId: node.id, - ...deepClone(initialNodeExecutionState), - }; - }); }, undo: (state) => state, redo: (state) => state, @@ -485,63 +449,6 @@ export const nodesSlice = createSlice({ edges.map((edge) => ({ item: edge, type: 'add' })), [] ); - - state.nodeExecutionStates = nodes.reduce>((acc, node) => { - acc[node.id] = { - nodeId: node.id, - ...initialNodeExecutionState, - }; - return acc; - }, {}); - }); - - builder.addCase(socketInvocationStarted, (state, action) => { - const { source_node_id } = action.payload.data; - const node = state.nodeExecutionStates[source_node_id]; - if (node) { - node.status = zNodeStatus.enum.IN_PROGRESS; - } - }); - builder.addCase(socketInvocationComplete, (state, action) => { - const { source_node_id, result } = action.payload.data; - const nes = state.nodeExecutionStates[source_node_id]; - if (nes) { - nes.status = zNodeStatus.enum.COMPLETED; - if (nes.progress !== null) { - nes.progress = 1; - } - nes.outputs.push(result); - } - }); - builder.addCase(socketInvocationError, (state, action) => { - const { source_node_id } = action.payload.data; - const node = state.nodeExecutionStates[source_node_id]; - if (node) { - node.status = zNodeStatus.enum.FAILED; - node.error = action.payload.data.error; - node.progress = null; - node.progressImage = null; - } - }); - builder.addCase(socketGeneratorProgress, (state, action) => { - const { source_node_id, step, total_steps, progress_image } = action.payload.data; - const node = state.nodeExecutionStates[source_node_id]; - if (node) { - node.status = zNodeStatus.enum.IN_PROGRESS; - node.progress = (step + 1) / total_steps; - node.progressImage = progress_image ?? null; - } - }); - builder.addCase(socketQueueItemStatusChanged, (state, action) => { - if (['in_progress'].includes(action.payload.data.queue_item.status)) { - forEach(state.nodeExecutionStates, (nes) => { - nes.status = zNodeStatus.enum.PENDING; - nes.error = null; - nes.progress = null; - nes.progressImage = null; - nes.outputs = []; - }); - } }); }, }); diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 090a967626..2f514bdb5b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -14,6 +14,7 @@ import type { import type { WorkflowV3 } from 'features/nodes/types/workflow'; export type Templates = Record; +export type NodeExecutionStates = Record; export type PendingConnection = { node: InvocationNode; @@ -25,7 +26,6 @@ export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; - nodeExecutionStates: Record; }; export type WorkflowMode = 'edit' | 'view';