feat(ui): extract node execution state from nodesSlice

This state is ephemeral and not undoable.
This commit is contained in:
psychedelicious 2024-05-17 13:21:01 +10:00
parent d2f5103f9f
commit ad8778df6c
12 changed files with 141 additions and 128 deletions

View File

@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketGeneratorProgress } from 'services/events/actions';
const log = logger('socketio');
@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
actionCreator: socketGeneratorProgress,
effect: (action) => {
log.trace(action.payload, `Generator progress`);
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
nes.progress = (step + 1) / total_steps;
nes.progressImage = progress_image ?? null;
}
},
});
};

View File

@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { parseify } from 'common/util/serialize';
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
import {
@ -9,7 +10,9 @@ import {
isImageViewerOpenChanged,
} from 'features/gallery/store/gallerySlice';
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { isImageOutput } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
import { boardsApi } from 'services/api/endpoints/boards';
import { imagesApi } from 'services/api/endpoints/images';
@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
const { data } = action.payload;
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
const { result, node, queue_batch_id } = data;
const { result, node, queue_batch_id, source_node_id } = data;
// This complete event has an associated image output
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
const { image_name } = result.image;
@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
}
}
}
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationError } from 'services/events/actions';
const log = logger('socketio');
@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
actionCreator: socketInvocationError,
effect: (action) => {
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.FAILED;
nes.error = action.payload.data.error;
nes.progress = null;
nes.progressImage = null;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -1,5 +1,8 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { socketInvocationStarted } from 'services/events/actions';
const log = logger('socketio');
@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
actionCreator: socketInvocationStarted,
effect: (action) => {
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
const { source_node_id } = action.payload.data;
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
if (nes) {
nes.status = zNodeStatus.enum.IN_PROGRESS;
upsertExecutionState(nes.nodeId, nes);
}
},
});
};

View File

@ -1,5 +1,9 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { deepClone } from 'common/util/deepClone';
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
import { socketQueueItemStatusChanged } from 'services/events/actions';
@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
dispatch(
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
);
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach($nodeExecutionStates.get(), (nes) => {
if (!nes) {
return;
}
const clone = deepClone(nes);
clone.status = zNodeStatus.enum.PENDING;
clone.error = null;
clone.progress = null;
clone.progressImage = null;
clone.outputs = [];
$nodeExecutionStates.setKey(clone.nodeId, clone);
});
}
},
});
};

View File

@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
@ -81,6 +82,7 @@ export const Flow = memo(() => {
const isValidConnection = useIsValidConnection();
const cancelConnection = useReactFlowStore(selectCancelConnection);
useWorkflowWatcher();
useSyncExecutionState();
const [borderRadius] = useToken('radii', ['base']);
const flowStyles = useMemo<CSSProperties>(

View File

@ -1,12 +1,10 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = {
};
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
const selectNodeExecutionState = useMemo(
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
[nodeId]
);
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
const nodeExecutionState = useExecutionState(nodeId);
if (!nodeExecutionState) {
return null;

View File

@ -1,14 +1,14 @@
import type { ChakraProps } from '@invoke-ai/ui-library';
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { MouseEvent, PropsWithChildren } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { memo, useCallback } from 'react';
type NodeWrapperProps = PropsWithChildren & {
nodeId: string;
@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
const selectIsInProgress = useMemo(
() =>
createSelector(
selectNodesSlice,
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
),
[nodeId]
);
const isInProgress = useAppSelector(selectIsInProgress);
const executionState = useExecutionState(nodeId);
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
'nodeInProgress',

View File

@ -5,6 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
import { isInvocationNode } from 'features/nodes/types/invocation';
@ -23,27 +24,26 @@ const InspectorOutputsTab = () => {
const lastSelectedNode = selectLastSelectedNode(nodes);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
const nes = nodes.nodeExecutionStates[lastSelectedNode?.id ?? '__UNKNOWN_NODE__'];
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
outputs: nes.outputs,
nodeId: lastSelectedNode.id,
outputType: lastSelectedNodeTemplate.outputType,
};
}),
[templates]
);
const data = useAppSelector(selector);
const nes = useExecutionState(data?.nodeId);
const { t } = useTranslation();
if (!data) {
if (!data || !nes) {
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
}
if (data.outputs.length === 0) {
if (nes.outputs.length === 0) {
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
}
@ -52,11 +52,11 @@ const InspectorOutputsTab = () => {
<ScrollableContent>
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
{data.outputType === 'image_output' ? (
data.outputs.map((result, i) => (
nes.outputs.map((result, i) => (
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
))
) : (
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
)}
</Flex>
</ScrollableContent>

View File

@ -0,0 +1,56 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { deepClone } from 'common/util/deepClone';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import type { NodeExecutionStates } from 'features/nodes/store/types';
import type { NodeExecutionState } from 'features/nodes/types/invocation';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { map } from 'nanostores';
import { useEffect, useMemo } from 'react';
export const $nodeExecutionStates = map<NodeExecutionStates>({});
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
export const useExecutionState = (nodeId?: string) => {
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
return executionState;
};
const removeNodeExecutionState = (nodeId: string) => {
$nodeExecutionStates.setKey(nodeId, undefined);
};
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
const state = $nodeExecutionStates.get()[nodeId];
if (!state) {
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
} else {
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
}
};
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
export const useSyncExecutionState = () => {
const nodeIds = useAppSelector(selectNodeIds);
useEffect(() => {
const nodeExecutionStates = $nodeExecutionStates.get();
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
for (const id of nodeIdsToAdd) {
upsertExecutionState(id);
}
for (const id of nodeIdsToRemove) {
removeNodeExecutionState(id);
}
}, [nodeIds]);
};

View File

@ -1,7 +1,6 @@
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
@ -43,38 +42,21 @@ import {
zT2IAdapterModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
import { forEach } from 'lodash-es';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom } from 'nanostores';
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
import type { UndoableOptions } from 'redux-undo';
import {
socketGeneratorProgress,
socketInvocationComplete,
socketInvocationError,
socketInvocationStarted,
socketQueueItemStatusChanged,
} from 'services/events/actions';
import type { z } from 'zod';
import type { NodesState, PendingConnection, Templates } from './types';
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
status: zNodeStatus.enum.PENDING,
error: null,
progress: null,
progressImage: null,
outputs: [],
};
const initialNodesState: NodesState = {
_version: 1,
nodes: [],
edges: [],
nodeExecutionStates: {},
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
@ -137,15 +119,6 @@ export const nodesSlice = createSlice({
);
state.nodes.push(node);
if (!isInvocationNode(node)) {
return;
}
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
},
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
state.edges = applyEdgeChanges(action.payload, state.edges);
@ -316,7 +289,6 @@ export const nodesSlice = createSlice({
if (!isInvocationNode(node)) {
return;
}
delete state.nodeExecutionStates[node.id];
});
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
@ -459,14 +431,6 @@ export const nodesSlice = createSlice({
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
state.edges = applyEdgeChanges(edgeChanges, state.edges);
// Add node execution states for new nodes
nodes.forEach((node) => {
state.nodeExecutionStates[node.id] = {
nodeId: node.id,
...deepClone(initialNodeExecutionState),
};
});
},
undo: (state) => state,
redo: (state) => state,
@ -485,63 +449,6 @@ export const nodesSlice = createSlice({
edges.map((edge) => ({ item: edge, type: 'add' })),
[]
);
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
acc[node.id] = {
nodeId: node.id,
...initialNodeExecutionState,
};
return acc;
}, {});
});
builder.addCase(socketInvocationStarted, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
}
});
builder.addCase(socketInvocationComplete, (state, action) => {
const { source_node_id, result } = action.payload.data;
const nes = state.nodeExecutionStates[source_node_id];
if (nes) {
nes.status = zNodeStatus.enum.COMPLETED;
if (nes.progress !== null) {
nes.progress = 1;
}
nes.outputs.push(result);
}
});
builder.addCase(socketInvocationError, (state, action) => {
const { source_node_id } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.FAILED;
node.error = action.payload.data.error;
node.progress = null;
node.progressImage = null;
}
});
builder.addCase(socketGeneratorProgress, (state, action) => {
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
const node = state.nodeExecutionStates[source_node_id];
if (node) {
node.status = zNodeStatus.enum.IN_PROGRESS;
node.progress = (step + 1) / total_steps;
node.progressImage = progress_image ?? null;
}
});
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
forEach(state.nodeExecutionStates, (nes) => {
nes.status = zNodeStatus.enum.PENDING;
nes.error = null;
nes.progress = null;
nes.progressImage = null;
nes.outputs = [];
});
}
});
},
});

View File

@ -14,6 +14,7 @@ import type {
import type { WorkflowV3 } from 'features/nodes/types/workflow';
export type Templates = Record<string, InvocationTemplate>;
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
export type PendingConnection = {
node: InvocationNode;
@ -25,7 +26,6 @@ export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
nodeExecutionStates: Record<string, NodeExecutionState>;
};
export type WorkflowMode = 'edit' | 'view';