mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): extract node execution state from nodesSlice
This state is ephemeral and not undoable.
This commit is contained in:
parent
d2f5103f9f
commit
ad8778df6c
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketGeneratorProgress } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketGeneratorProgress,
|
||||
effect: (action) => {
|
||||
log.trace(action.payload, `Generator progress`);
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
nes.progress = (step + 1) / total_steps;
|
||||
nes.progressImage = progress_image ?? null;
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||
import {
|
||||
@ -9,7 +10,9 @@ import {
|
||||
isImageViewerOpenChanged,
|
||||
} from 'features/gallery/store/gallerySlice';
|
||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { isImageOutput } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
const { data } = action.payload;
|
||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||
|
||||
const { result, node, queue_batch_id } = data;
|
||||
const { result, node, queue_batch_id, source_node_id } = data;
|
||||
// This complete event has an associated image output
|
||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||
const { image_name } = result.image;
|
||||
@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationError } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
||||
actionCreator: socketInvocationError,
|
||||
effect: (action) => {
|
||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.FAILED;
|
||||
nes.error = action.payload.data.error;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,8 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { socketInvocationStarted } from 'services/events/actions';
|
||||
|
||||
const log = logger('socketio');
|
||||
@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
||||
actionCreator: socketInvocationStarted,
|
||||
effect: (action) => {
|
||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||
const { source_node_id } = action.payload.data;
|
||||
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -1,5 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||
|
||||
@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
||||
dispatch(
|
||||
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
||||
);
|
||||
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach($nodeExecutionStates.get(), (nes) => {
|
||||
if (!nes) {
|
||||
return;
|
||||
}
|
||||
const clone = deepClone(nes);
|
||||
clone.status = zNodeStatus.enum.PENDING;
|
||||
clone.error = null;
|
||||
clone.progress = null;
|
||||
clone.progressImage = null;
|
||||
clone.outputs = [];
|
||||
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
|
@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
@ -81,6 +82,7 @@ export const Flow = memo(() => {
|
||||
const isValidConnection = useIsValidConnection();
|
||||
const cancelConnection = useReactFlowStore(selectCancelConnection);
|
||||
useWorkflowWatcher();
|
||||
useSyncExecutionState();
|
||||
const [borderRadius] = useToken('radii', ['base']);
|
||||
|
||||
const flowStyles = useMemo<CSSProperties>(
|
||||
|
@ -1,12 +1,10 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
|
||||
|
||||
@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = {
|
||||
};
|
||||
|
||||
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
|
||||
const selectNodeExecutionState = useMemo(
|
||||
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
|
||||
const nodeExecutionState = useExecutionState(nodeId);
|
||||
|
||||
if (!nodeExecutionState) {
|
||||
return null;
|
||||
|
@ -1,14 +1,14 @@
|
||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import type { MouseEvent, PropsWithChildren } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
|
||||
type NodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||
|
||||
const selectIsInProgress = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectNodesSlice,
|
||||
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
|
||||
),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const isInProgress = useAppSelector(selectIsInProgress);
|
||||
const executionState = useExecutionState(nodeId);
|
||||
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||
|
||||
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
|
||||
'nodeInProgress',
|
||||
|
@ -5,6 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
@ -23,27 +24,26 @@ const InspectorOutputsTab = () => {
|
||||
const lastSelectedNode = selectLastSelectedNode(nodes);
|
||||
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
||||
|
||||
const nes = nodes.nodeExecutionStates[lastSelectedNode?.id ?? '__UNKNOWN_NODE__'];
|
||||
|
||||
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
|
||||
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
return {
|
||||
outputs: nes.outputs,
|
||||
nodeId: lastSelectedNode.id,
|
||||
outputType: lastSelectedNodeTemplate.outputType,
|
||||
};
|
||||
}),
|
||||
[templates]
|
||||
);
|
||||
const data = useAppSelector(selector);
|
||||
const nes = useExecutionState(data?.nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
if (!data) {
|
||||
if (!data || !nes) {
|
||||
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
|
||||
}
|
||||
|
||||
if (data.outputs.length === 0) {
|
||||
if (nes.outputs.length === 0) {
|
||||
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
|
||||
}
|
||||
|
||||
@ -52,11 +52,11 @@ const InspectorOutputsTab = () => {
|
||||
<ScrollableContent>
|
||||
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
||||
{data.outputType === 'image_output' ? (
|
||||
data.outputs.map((result, i) => (
|
||||
nes.outputs.map((result, i) => (
|
||||
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
|
||||
))
|
||||
) : (
|
||||
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
|
||||
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
|
||||
)}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
|
@ -0,0 +1,56 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||
import type { NodeExecutionStates } from 'features/nodes/store/types';
|
||||
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { map } from 'nanostores';
|
||||
import { useEffect, useMemo } from 'react';
|
||||
|
||||
export const $nodeExecutionStates = map<NodeExecutionStates>({});
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
export const useExecutionState = (nodeId?: string) => {
|
||||
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
|
||||
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
|
||||
return executionState;
|
||||
};
|
||||
|
||||
const removeNodeExecutionState = (nodeId: string) => {
|
||||
$nodeExecutionStates.setKey(nodeId, undefined);
|
||||
};
|
||||
|
||||
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
|
||||
const state = $nodeExecutionStates.get()[nodeId];
|
||||
if (!state) {
|
||||
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
|
||||
} else {
|
||||
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
|
||||
}
|
||||
};
|
||||
|
||||
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
|
||||
|
||||
export const useSyncExecutionState = () => {
|
||||
const nodeIds = useAppSelector(selectNodeIds);
|
||||
useEffect(() => {
|
||||
const nodeExecutionStates = $nodeExecutionStates.get();
|
||||
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
|
||||
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
|
||||
for (const id of nodeIdsToAdd) {
|
||||
upsertExecutionState(id);
|
||||
}
|
||||
for (const id of nodeIdsToRemove) {
|
||||
removeNodeExecutionState(id);
|
||||
}
|
||||
}, [nodeIds]);
|
||||
};
|
@ -1,7 +1,6 @@
|
||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
@ -43,38 +42,21 @@ import {
|
||||
zT2IAdapterModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { forEach } from 'lodash-es';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { atom } from 'nanostores';
|
||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||
import type { UndoableOptions } from 'redux-undo';
|
||||
import {
|
||||
socketGeneratorProgress,
|
||||
socketInvocationComplete,
|
||||
socketInvocationError,
|
||||
socketInvocationStarted,
|
||||
socketQueueItemStatusChanged,
|
||||
} from 'services/events/actions';
|
||||
import type { z } from 'zod';
|
||||
|
||||
import type { NodesState, PendingConnection, Templates } from './types';
|
||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||
|
||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||
status: zNodeStatus.enum.PENDING,
|
||||
error: null,
|
||||
progress: null,
|
||||
progressImage: null,
|
||||
outputs: [],
|
||||
};
|
||||
|
||||
const initialNodesState: NodesState = {
|
||||
_version: 1,
|
||||
nodes: [],
|
||||
edges: [],
|
||||
nodeExecutionStates: {},
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
@ -137,15 +119,6 @@ export const nodesSlice = createSlice({
|
||||
);
|
||||
|
||||
state.nodes.push(node);
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
state.nodeExecutionStates[node.id] = {
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
},
|
||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||
@ -316,7 +289,6 @@ export const nodesSlice = createSlice({
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
delete state.nodeExecutionStates[node.id];
|
||||
});
|
||||
},
|
||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||
@ -459,14 +431,6 @@ export const nodesSlice = createSlice({
|
||||
|
||||
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||
|
||||
// Add node execution states for new nodes
|
||||
nodes.forEach((node) => {
|
||||
state.nodeExecutionStates[node.id] = {
|
||||
nodeId: node.id,
|
||||
...deepClone(initialNodeExecutionState),
|
||||
};
|
||||
});
|
||||
},
|
||||
undo: (state) => state,
|
||||
redo: (state) => state,
|
||||
@ -485,63 +449,6 @@ export const nodesSlice = createSlice({
|
||||
edges.map((edge) => ({ item: edge, type: 'add' })),
|
||||
[]
|
||||
);
|
||||
|
||||
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
|
||||
acc[node.id] = {
|
||||
nodeId: node.id,
|
||||
...initialNodeExecutionState,
|
||||
};
|
||||
return acc;
|
||||
}, {});
|
||||
});
|
||||
|
||||
builder.addCase(socketInvocationStarted, (state, action) => {
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketInvocationComplete, (state, action) => {
|
||||
const { source_node_id, result } = action.payload.data;
|
||||
const nes = state.nodeExecutionStates[source_node_id];
|
||||
if (nes) {
|
||||
nes.status = zNodeStatus.enum.COMPLETED;
|
||||
if (nes.progress !== null) {
|
||||
nes.progress = 1;
|
||||
}
|
||||
nes.outputs.push(result);
|
||||
}
|
||||
});
|
||||
builder.addCase(socketInvocationError, (state, action) => {
|
||||
const { source_node_id } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.FAILED;
|
||||
node.error = action.payload.data.error;
|
||||
node.progress = null;
|
||||
node.progressImage = null;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||
const node = state.nodeExecutionStates[source_node_id];
|
||||
if (node) {
|
||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
||||
node.progress = (step + 1) / total_steps;
|
||||
node.progressImage = progress_image ?? null;
|
||||
}
|
||||
});
|
||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||
forEach(state.nodeExecutionStates, (nes) => {
|
||||
nes.status = zNodeStatus.enum.PENDING;
|
||||
nes.error = null;
|
||||
nes.progress = null;
|
||||
nes.progressImage = null;
|
||||
nes.outputs = [];
|
||||
});
|
||||
}
|
||||
});
|
||||
},
|
||||
});
|
||||
|
@ -14,6 +14,7 @@ import type {
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
|
||||
export type Templates = Record<string, InvocationTemplate>;
|
||||
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||
|
||||
export type PendingConnection = {
|
||||
node: InvocationNode;
|
||||
@ -25,7 +26,6 @@ export type NodesState = {
|
||||
_version: 1;
|
||||
nodes: AnyNode[];
|
||||
edges: InvocationNodeEdge[];
|
||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
|
Loading…
Reference in New Issue
Block a user