mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): extract node execution state from nodesSlice
This state is ephemeral and not undoable.
This commit is contained in:
parent
d2f5103f9f
commit
ad8778df6c
@ -1,5 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { socketGeneratorProgress } from 'services/events/actions';
|
import { socketGeneratorProgress } from 'services/events/actions';
|
||||||
|
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
@ -9,6 +12,13 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis
|
|||||||
actionCreator: socketGeneratorProgress,
|
actionCreator: socketGeneratorProgress,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
log.trace(action.payload, `Generator progress`);
|
log.trace(action.payload, `Generator progress`);
|
||||||
|
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
||||||
|
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||||
|
if (nes) {
|
||||||
|
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||||
|
nes.progress = (step + 1) / total_steps;
|
||||||
|
nes.progressImage = progress_image ?? null;
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
import { parseify } from 'common/util/serialize';
|
import { parseify } from 'common/util/serialize';
|
||||||
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
import { addImageToStagingArea } from 'features/canvas/store/canvasSlice';
|
||||||
import {
|
import {
|
||||||
@ -9,7 +10,9 @@ import {
|
|||||||
isImageViewerOpenChanged,
|
isImageViewerOpenChanged,
|
||||||
} from 'features/gallery/store/gallerySlice';
|
} from 'features/gallery/store/gallerySlice';
|
||||||
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
import { IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||||
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { isImageOutput } from 'features/nodes/types/common';
|
import { isImageOutput } from 'features/nodes/types/common';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
import { CANVAS_OUTPUT } from 'features/nodes/util/graph/constants';
|
||||||
import { boardsApi } from 'services/api/endpoints/boards';
|
import { boardsApi } from 'services/api/endpoints/boards';
|
||||||
import { imagesApi } from 'services/api/endpoints/images';
|
import { imagesApi } from 'services/api/endpoints/images';
|
||||||
@ -28,7 +31,7 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
|||||||
const { data } = action.payload;
|
const { data } = action.payload;
|
||||||
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
log.debug({ data: parseify(data) }, `Invocation complete (${action.payload.data.node.type})`);
|
||||||
|
|
||||||
const { result, node, queue_batch_id } = data;
|
const { result, node, queue_batch_id, source_node_id } = data;
|
||||||
// This complete event has an associated image output
|
// This complete event has an associated image output
|
||||||
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
if (isImageOutput(result) && !nodeTypeDenylist.includes(node.type)) {
|
||||||
const { image_name } = result.image;
|
const { image_name } = result.image;
|
||||||
@ -110,6 +113,16 @@ export const addInvocationCompleteEventListener = (startAppListening: AppStartLi
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||||
|
if (nes) {
|
||||||
|
nes.status = zNodeStatus.enum.COMPLETED;
|
||||||
|
if (nes.progress !== null) {
|
||||||
|
nes.progress = 1;
|
||||||
|
}
|
||||||
|
nes.outputs.push(result);
|
||||||
|
upsertExecutionState(nes.nodeId, nes);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { socketInvocationError } from 'services/events/actions';
|
import { socketInvocationError } from 'services/events/actions';
|
||||||
|
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
@ -9,6 +12,15 @@ export const addInvocationErrorEventListener = (startAppListening: AppStartListe
|
|||||||
actionCreator: socketInvocationError,
|
actionCreator: socketInvocationError,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
log.error(action.payload, `Invocation error (${action.payload.data.node.type})`);
|
||||||
|
const { source_node_id } = action.payload.data;
|
||||||
|
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||||
|
if (nes) {
|
||||||
|
nes.status = zNodeStatus.enum.FAILED;
|
||||||
|
nes.error = action.payload.data.error;
|
||||||
|
nes.progress = null;
|
||||||
|
nes.progressImage = null;
|
||||||
|
upsertExecutionState(nes.nodeId, nes);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { socketInvocationStarted } from 'services/events/actions';
|
import { socketInvocationStarted } from 'services/events/actions';
|
||||||
|
|
||||||
const log = logger('socketio');
|
const log = logger('socketio');
|
||||||
@ -9,6 +12,12 @@ export const addInvocationStartedEventListener = (startAppListening: AppStartLis
|
|||||||
actionCreator: socketInvocationStarted,
|
actionCreator: socketInvocationStarted,
|
||||||
effect: (action) => {
|
effect: (action) => {
|
||||||
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
log.debug(action.payload, `Invocation started (${action.payload.data.node.type})`);
|
||||||
|
const { source_node_id } = action.payload.data;
|
||||||
|
const nes = deepClone($nodeExecutionStates.get()[source_node_id]);
|
||||||
|
if (nes) {
|
||||||
|
nes.status = zNodeStatus.enum.IN_PROGRESS;
|
||||||
|
upsertExecutionState(nes.nodeId, nes);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -1,5 +1,9 @@
|
|||||||
import { logger } from 'app/logging/logger';
|
import { logger } from 'app/logging/logger';
|
||||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
|
import { forEach } from 'lodash-es';
|
||||||
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
import { queueApi, queueItemsAdapter } from 'services/api/endpoints/queue';
|
||||||
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
import { socketQueueItemStatusChanged } from 'services/events/actions';
|
||||||
|
|
||||||
@ -54,6 +58,21 @@ export const addSocketQueueItemStatusChangedEventListener = (startAppListening:
|
|||||||
dispatch(
|
dispatch(
|
||||||
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
queueApi.util.invalidateTags(['CurrentSessionQueueItem', 'NextSessionQueueItem', 'InvocationCacheStatus'])
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
||||||
|
forEach($nodeExecutionStates.get(), (nes) => {
|
||||||
|
if (!nes) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const clone = deepClone(nes);
|
||||||
|
clone.status = zNodeStatus.enum.PENDING;
|
||||||
|
clone.error = null;
|
||||||
|
clone.progress = null;
|
||||||
|
clone.progressImage = null;
|
||||||
|
clone.outputs = [];
|
||||||
|
$nodeExecutionStates.setKey(clone.nodeId, clone);
|
||||||
|
});
|
||||||
|
}
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -3,6 +3,7 @@ import { useStore } from '@nanostores/react';
|
|||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||||
|
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection';
|
||||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||||
import {
|
import {
|
||||||
@ -81,6 +82,7 @@ export const Flow = memo(() => {
|
|||||||
const isValidConnection = useIsValidConnection();
|
const isValidConnection = useIsValidConnection();
|
||||||
const cancelConnection = useReactFlowStore(selectCancelConnection);
|
const cancelConnection = useReactFlowStore(selectCancelConnection);
|
||||||
useWorkflowWatcher();
|
useWorkflowWatcher();
|
||||||
|
useSyncExecutionState();
|
||||||
const [borderRadius] = useToken('radii', ['base']);
|
const [borderRadius] = useToken('radii', ['base']);
|
||||||
|
|
||||||
const flowStyles = useMemo<CSSProperties>(
|
const flowStyles = useMemo<CSSProperties>(
|
||||||
|
@ -1,12 +1,10 @@
|
|||||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||||
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
|
import { Badge, CircularProgress, Flex, Icon, Image, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { useAppSelector } from 'app/store/storeHooks';
|
|
||||||
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
|
||||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||||
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import { memo, useMemo } from 'react';
|
import { memo } from 'react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
|
import { PiCheckBold, PiDotsThreeOutlineFill, PiWarningBold } from 'react-icons/pi';
|
||||||
|
|
||||||
@ -24,12 +22,7 @@ const circleStyles: SystemStyleObject = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
|
const InvocationNodeStatusIndicator = ({ nodeId }: Props) => {
|
||||||
const selectNodeExecutionState = useMemo(
|
const nodeExecutionState = useExecutionState(nodeId);
|
||||||
() => createMemoizedSelector(selectNodesSlice, (nodes) => nodes.nodeExecutionStates[nodeId]),
|
|
||||||
[nodeId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const nodeExecutionState = useAppSelector(selectNodeExecutionState);
|
|
||||||
|
|
||||||
if (!nodeExecutionState) {
|
if (!nodeExecutionState) {
|
||||||
return null;
|
return null;
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
import type { ChakraProps } from '@invoke-ai/ui-library';
|
import type { ChakraProps } from '@invoke-ai/ui-library';
|
||||||
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||||
import { createSelector } from '@reduxjs/toolkit';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay';
|
||||||
|
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import { nodeExclusivelySelected, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice';
|
||||||
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
|
import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants';
|
||||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
import type { MouseEvent, PropsWithChildren } from 'react';
|
import type { MouseEvent, PropsWithChildren } from 'react';
|
||||||
import { memo, useCallback, useMemo } from 'react';
|
import { memo, useCallback } from 'react';
|
||||||
|
|
||||||
type NodeWrapperProps = PropsWithChildren & {
|
type NodeWrapperProps = PropsWithChildren & {
|
||||||
nodeId: string;
|
nodeId: string;
|
||||||
@ -20,16 +20,8 @@ const NodeWrapper = (props: NodeWrapperProps) => {
|
|||||||
const { nodeId, width, children, selected } = props;
|
const { nodeId, width, children, selected } = props;
|
||||||
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId);
|
||||||
|
|
||||||
const selectIsInProgress = useMemo(
|
const executionState = useExecutionState(nodeId);
|
||||||
() =>
|
const isInProgress = executionState?.status === zNodeStatus.enum.IN_PROGRESS;
|
||||||
createSelector(
|
|
||||||
selectNodesSlice,
|
|
||||||
(nodes) => nodes.nodeExecutionStates[nodeId]?.status === zNodeStatus.enum.IN_PROGRESS
|
|
||||||
),
|
|
||||||
[nodeId]
|
|
||||||
);
|
|
||||||
|
|
||||||
const isInProgress = useAppSelector(selectIsInProgress);
|
|
||||||
|
|
||||||
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
|
const [nodeInProgress, shadowsXl, shadowsBase] = useToken('shadows', [
|
||||||
'nodeInProgress',
|
'nodeInProgress',
|
||||||
|
@ -5,6 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks';
|
|||||||
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
|
||||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||||
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
|
||||||
|
import { useExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||||
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
import { selectLastSelectedNode } from 'features/nodes/store/selectors';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
@ -23,27 +24,26 @@ const InspectorOutputsTab = () => {
|
|||||||
const lastSelectedNode = selectLastSelectedNode(nodes);
|
const lastSelectedNode = selectLastSelectedNode(nodes);
|
||||||
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
|
||||||
|
|
||||||
const nes = nodes.nodeExecutionStates[lastSelectedNode?.id ?? '__UNKNOWN_NODE__'];
|
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
|
||||||
|
|
||||||
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
outputs: nes.outputs,
|
nodeId: lastSelectedNode.id,
|
||||||
outputType: lastSelectedNodeTemplate.outputType,
|
outputType: lastSelectedNodeTemplate.outputType,
|
||||||
};
|
};
|
||||||
}),
|
}),
|
||||||
[templates]
|
[templates]
|
||||||
);
|
);
|
||||||
const data = useAppSelector(selector);
|
const data = useAppSelector(selector);
|
||||||
|
const nes = useExecutionState(data?.nodeId);
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
if (!data) {
|
if (!data || !nes) {
|
||||||
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
|
return <IAINoContentFallback label={t('nodes.noNodeSelected')} icon={null} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.outputs.length === 0) {
|
if (nes.outputs.length === 0) {
|
||||||
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
|
return <IAINoContentFallback label={t('nodes.noOutputRecorded')} icon={null} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -52,11 +52,11 @@ const InspectorOutputsTab = () => {
|
|||||||
<ScrollableContent>
|
<ScrollableContent>
|
||||||
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
<Flex position="relative" flexDir="column" alignItems="flex-start" p={1} gap={2} h="full" w="full">
|
||||||
{data.outputType === 'image_output' ? (
|
{data.outputType === 'image_output' ? (
|
||||||
data.outputs.map((result, i) => (
|
nes.outputs.map((result, i) => (
|
||||||
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
|
<ImageOutputPreview key={getKey(result, i)} output={result as ImageOutput} />
|
||||||
))
|
))
|
||||||
) : (
|
) : (
|
||||||
<DataViewer data={data.outputs} label={t('nodes.nodeOutputs')} />
|
<DataViewer data={nes.outputs} label={t('nodes.nodeOutputs')} />
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
</ScrollableContent>
|
</ScrollableContent>
|
||||||
|
@ -0,0 +1,56 @@
|
|||||||
|
import { useStore } from '@nanostores/react';
|
||||||
|
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||||
|
import { useAppSelector } from 'app/store/storeHooks';
|
||||||
|
import { deepClone } from 'common/util/deepClone';
|
||||||
|
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
|
||||||
|
import type { NodeExecutionStates } from 'features/nodes/store/types';
|
||||||
|
import type { NodeExecutionState } from 'features/nodes/types/invocation';
|
||||||
|
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||||
|
import { map } from 'nanostores';
|
||||||
|
import { useEffect, useMemo } from 'react';
|
||||||
|
|
||||||
|
export const $nodeExecutionStates = map<NodeExecutionStates>({});
|
||||||
|
|
||||||
|
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
||||||
|
status: zNodeStatus.enum.PENDING,
|
||||||
|
error: null,
|
||||||
|
progress: null,
|
||||||
|
progressImage: null,
|
||||||
|
outputs: [],
|
||||||
|
};
|
||||||
|
|
||||||
|
export const useExecutionState = (nodeId?: string) => {
|
||||||
|
const executionStates = useStore($nodeExecutionStates, nodeId ? { keys: [nodeId] } : undefined);
|
||||||
|
const executionState = useMemo(() => (nodeId ? executionStates[nodeId] : undefined), [executionStates, nodeId]);
|
||||||
|
return executionState;
|
||||||
|
};
|
||||||
|
|
||||||
|
const removeNodeExecutionState = (nodeId: string) => {
|
||||||
|
$nodeExecutionStates.setKey(nodeId, undefined);
|
||||||
|
};
|
||||||
|
|
||||||
|
export const upsertExecutionState = (nodeId: string, updates?: Partial<NodeExecutionState>) => {
|
||||||
|
const state = $nodeExecutionStates.get()[nodeId];
|
||||||
|
if (!state) {
|
||||||
|
$nodeExecutionStates.setKey(nodeId, { ...deepClone(initialNodeExecutionState), nodeId, ...updates });
|
||||||
|
} else {
|
||||||
|
$nodeExecutionStates.setKey(nodeId, { ...state, ...updates });
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const selectNodeIds = createMemoizedSelector(selectNodesSlice, (nodesSlice) => nodesSlice.nodes.map((node) => node.id));
|
||||||
|
|
||||||
|
export const useSyncExecutionState = () => {
|
||||||
|
const nodeIds = useAppSelector(selectNodeIds);
|
||||||
|
useEffect(() => {
|
||||||
|
const nodeExecutionStates = $nodeExecutionStates.get();
|
||||||
|
const nodeIdsToAdd = nodeIds.filter((id) => !nodeExecutionStates[id]);
|
||||||
|
const nodeIdsToRemove = Object.keys(nodeExecutionStates).filter((id) => !nodeIds.includes(id));
|
||||||
|
for (const id of nodeIdsToAdd) {
|
||||||
|
upsertExecutionState(id);
|
||||||
|
}
|
||||||
|
for (const id of nodeIdsToRemove) {
|
||||||
|
removeNodeExecutionState(id);
|
||||||
|
}
|
||||||
|
}, [nodeIds]);
|
||||||
|
};
|
@ -1,7 +1,6 @@
|
|||||||
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit';
|
||||||
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||||
import type { PersistConfig, RootState } from 'app/store/store';
|
import type { PersistConfig, RootState } from 'app/store/store';
|
||||||
import { deepClone } from 'common/util/deepClone';
|
|
||||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||||
import type {
|
import type {
|
||||||
@ -43,38 +42,21 @@ import {
|
|||||||
zT2IAdapterModelFieldValue,
|
zT2IAdapterModelFieldValue,
|
||||||
zVAEModelFieldValue,
|
zVAEModelFieldValue,
|
||||||
} from 'features/nodes/types/field';
|
} from 'features/nodes/types/field';
|
||||||
import type { AnyNode, InvocationNodeEdge, NodeExecutionState } from 'features/nodes/types/invocation';
|
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||||
import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation';
|
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||||
import { forEach } from 'lodash-es';
|
|
||||||
import { atom } from 'nanostores';
|
import { atom } from 'nanostores';
|
||||||
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||||
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow';
|
||||||
import type { UndoableOptions } from 'redux-undo';
|
import type { UndoableOptions } from 'redux-undo';
|
||||||
import {
|
|
||||||
socketGeneratorProgress,
|
|
||||||
socketInvocationComplete,
|
|
||||||
socketInvocationError,
|
|
||||||
socketInvocationStarted,
|
|
||||||
socketQueueItemStatusChanged,
|
|
||||||
} from 'services/events/actions';
|
|
||||||
import type { z } from 'zod';
|
import type { z } from 'zod';
|
||||||
|
|
||||||
import type { NodesState, PendingConnection, Templates } from './types';
|
import type { NodesState, PendingConnection, Templates } from './types';
|
||||||
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
import { findUnoccupiedPosition } from './util/findUnoccupiedPosition';
|
||||||
|
|
||||||
const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
|
|
||||||
status: zNodeStatus.enum.PENDING,
|
|
||||||
error: null,
|
|
||||||
progress: null,
|
|
||||||
progressImage: null,
|
|
||||||
outputs: [],
|
|
||||||
};
|
|
||||||
|
|
||||||
const initialNodesState: NodesState = {
|
const initialNodesState: NodesState = {
|
||||||
_version: 1,
|
_version: 1,
|
||||||
nodes: [],
|
nodes: [],
|
||||||
edges: [],
|
edges: [],
|
||||||
nodeExecutionStates: {},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||||
@ -137,15 +119,6 @@ export const nodesSlice = createSlice({
|
|||||||
);
|
);
|
||||||
|
|
||||||
state.nodes.push(node);
|
state.nodes.push(node);
|
||||||
|
|
||||||
if (!isInvocationNode(node)) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
state.nodeExecutionStates[node.id] = {
|
|
||||||
nodeId: node.id,
|
|
||||||
...initialNodeExecutionState,
|
|
||||||
};
|
|
||||||
},
|
},
|
||||||
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
edgesChanged: (state, action: PayloadAction<EdgeChange[]>) => {
|
||||||
state.edges = applyEdgeChanges(action.payload, state.edges);
|
state.edges = applyEdgeChanges(action.payload, state.edges);
|
||||||
@ -316,7 +289,6 @@ export const nodesSlice = createSlice({
|
|||||||
if (!isInvocationNode(node)) {
|
if (!isInvocationNode(node)) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
delete state.nodeExecutionStates[node.id];
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||||
@ -459,14 +431,6 @@ export const nodesSlice = createSlice({
|
|||||||
|
|
||||||
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
state.nodes = applyNodeChanges(nodeChanges, state.nodes);
|
||||||
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
state.edges = applyEdgeChanges(edgeChanges, state.edges);
|
||||||
|
|
||||||
// Add node execution states for new nodes
|
|
||||||
nodes.forEach((node) => {
|
|
||||||
state.nodeExecutionStates[node.id] = {
|
|
||||||
nodeId: node.id,
|
|
||||||
...deepClone(initialNodeExecutionState),
|
|
||||||
};
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
undo: (state) => state,
|
undo: (state) => state,
|
||||||
redo: (state) => state,
|
redo: (state) => state,
|
||||||
@ -485,63 +449,6 @@ export const nodesSlice = createSlice({
|
|||||||
edges.map((edge) => ({ item: edge, type: 'add' })),
|
edges.map((edge) => ({ item: edge, type: 'add' })),
|
||||||
[]
|
[]
|
||||||
);
|
);
|
||||||
|
|
||||||
state.nodeExecutionStates = nodes.reduce<Record<string, NodeExecutionState>>((acc, node) => {
|
|
||||||
acc[node.id] = {
|
|
||||||
nodeId: node.id,
|
|
||||||
...initialNodeExecutionState,
|
|
||||||
};
|
|
||||||
return acc;
|
|
||||||
}, {});
|
|
||||||
});
|
|
||||||
|
|
||||||
builder.addCase(socketInvocationStarted, (state, action) => {
|
|
||||||
const { source_node_id } = action.payload.data;
|
|
||||||
const node = state.nodeExecutionStates[source_node_id];
|
|
||||||
if (node) {
|
|
||||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(socketInvocationComplete, (state, action) => {
|
|
||||||
const { source_node_id, result } = action.payload.data;
|
|
||||||
const nes = state.nodeExecutionStates[source_node_id];
|
|
||||||
if (nes) {
|
|
||||||
nes.status = zNodeStatus.enum.COMPLETED;
|
|
||||||
if (nes.progress !== null) {
|
|
||||||
nes.progress = 1;
|
|
||||||
}
|
|
||||||
nes.outputs.push(result);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(socketInvocationError, (state, action) => {
|
|
||||||
const { source_node_id } = action.payload.data;
|
|
||||||
const node = state.nodeExecutionStates[source_node_id];
|
|
||||||
if (node) {
|
|
||||||
node.status = zNodeStatus.enum.FAILED;
|
|
||||||
node.error = action.payload.data.error;
|
|
||||||
node.progress = null;
|
|
||||||
node.progressImage = null;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(socketGeneratorProgress, (state, action) => {
|
|
||||||
const { source_node_id, step, total_steps, progress_image } = action.payload.data;
|
|
||||||
const node = state.nodeExecutionStates[source_node_id];
|
|
||||||
if (node) {
|
|
||||||
node.status = zNodeStatus.enum.IN_PROGRESS;
|
|
||||||
node.progress = (step + 1) / total_steps;
|
|
||||||
node.progressImage = progress_image ?? null;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
builder.addCase(socketQueueItemStatusChanged, (state, action) => {
|
|
||||||
if (['in_progress'].includes(action.payload.data.queue_item.status)) {
|
|
||||||
forEach(state.nodeExecutionStates, (nes) => {
|
|
||||||
nes.status = zNodeStatus.enum.PENDING;
|
|
||||||
nes.error = null;
|
|
||||||
nes.progress = null;
|
|
||||||
nes.progressImage = null;
|
|
||||||
nes.outputs = [];
|
|
||||||
});
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
@ -14,6 +14,7 @@ import type {
|
|||||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||||
|
|
||||||
export type Templates = Record<string, InvocationTemplate>;
|
export type Templates = Record<string, InvocationTemplate>;
|
||||||
|
export type NodeExecutionStates = Record<string, NodeExecutionState | undefined>;
|
||||||
|
|
||||||
export type PendingConnection = {
|
export type PendingConnection = {
|
||||||
node: InvocationNode;
|
node: InvocationNode;
|
||||||
@ -25,7 +26,6 @@ export type NodesState = {
|
|||||||
_version: 1;
|
_version: 1;
|
||||||
nodes: AnyNode[];
|
nodes: AnyNode[];
|
||||||
edges: InvocationNodeEdge[];
|
edges: InvocationNodeEdge[];
|
||||||
nodeExecutionStates: Record<string, NodeExecutionState>;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
export type WorkflowMode = 'edit' | 'view';
|
export type WorkflowMode = 'edit' | 'view';
|
||||||
|
Loading…
Reference in New Issue
Block a user