feat(ui): store node templates in separate slice

Flattens the `nodes` slice. May offer minor perf improvements in addition to just being cleaner.
This commit is contained in:
psychedelicious 2024-01-01 12:27:05 +11:00 committed by Kent Keirsey
parent 7c548c5bf3
commit 5d4610d981
33 changed files with 200 additions and 167 deletions

View File

@ -1,6 +1,6 @@
import type { UnknownAction } from '@reduxjs/toolkit'; import type { UnknownAction } from '@reduxjs/toolkit';
import { isAnyGraphBuilt } from 'features/nodes/store/actions'; import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';
import type { Graph } from 'services/api/types'; import type { Graph } from 'services/api/types';

View File

@ -1,6 +1,6 @@
import { logger } from 'app/logging/logger'; import { logger } from 'app/logging/logger';
import { parseify } from 'common/util/serialize'; import { parseify } from 'common/util/serialize';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es'; import { size } from 'lodash-es';
import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { receivedOpenAPISchema } from 'services/api/thunks/schema';

View File

@ -15,11 +15,11 @@ export const addSocketConnectedEventListener = () => {
log.debug('Connected'); log.debug('Connected');
const { nodes, config, system } = getState(); const { nodeTemplates, config, system } = getState();
const { disabledTabs } = config; const { disabledTabs } = config;
if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) { if (!size(nodeTemplates.templates) && !disabledTabs.includes('nodes')) {
dispatch(receivedOpenAPISchema()); dispatch(receivedOpenAPISchema());
} }

View File

@ -19,7 +19,7 @@ export const addUpdateAllNodesRequestedListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const log = logger('nodes'); const log = logger('nodes');
const nodes = getState().nodes.nodes; const nodes = getState().nodes.nodes;
const templates = getState().nodes.nodeTemplates; const templates = getState().nodeTemplates.templates;
let unableToUpdateCount = 0; let unableToUpdateCount = 0;

View File

@ -25,7 +25,7 @@ export const addWorkflowLoadRequestedListener = () => {
effect: (action, { dispatch, getState }) => { effect: (action, { dispatch, getState }) => {
const log = logger('nodes'); const log = logger('nodes');
const { workflow, asCopy } = action.payload; const { workflow, asCopy } = action.payload;
const nodeTemplates = getState().nodes.nodeTemplates; const nodeTemplates = getState().nodeTemplates.templates;
try { try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow( const { workflow: validatedWorkflow, warnings } = validateWorkflow(

View File

@ -14,6 +14,7 @@ import hrfReducer from 'features/hrf/store/hrfSlice';
import loraReducer from 'features/lora/store/loraSlice'; import loraReducer from 'features/lora/store/loraSlice';
import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice'; import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice';
import nodesReducer from 'features/nodes/store/nodesSlice'; import nodesReducer from 'features/nodes/store/nodesSlice';
import nodeTemplatesReducer from 'features/nodes/store/nodeTemplatesSlice';
import workflowReducer from 'features/nodes/store/workflowSlice'; import workflowReducer from 'features/nodes/store/workflowSlice';
import generationReducer from 'features/parameters/store/generationSlice'; import generationReducer from 'features/parameters/store/generationSlice';
import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice';
@ -42,6 +43,7 @@ const allReducers = {
gallery: galleryReducer, gallery: galleryReducer,
generation: generationReducer, generation: generationReducer,
nodes: nodesReducer, nodes: nodesReducer,
nodeTemplates: nodeTemplatesReducer,
postprocessing: postprocessingReducer, postprocessing: postprocessingReducer,
system: systemReducer, system: systemReducer,
config: configReducer, config: configReducer,

View File

@ -12,7 +12,14 @@ import { getConnectedEdges } from 'reactflow';
const selector = createMemoizedSelector( const selector = createMemoizedSelector(
[stateSelector, activeTabNameSelector], [stateSelector, activeTabNameSelector],
( (
{ controlAdapters, generation, system, nodes, dynamicPrompts }, {
controlAdapters,
generation,
system,
nodes,
nodeTemplates,
dynamicPrompts,
},
activeTabName activeTabName
) => { ) => {
const { initialImage, model } = generation; const { initialImage, model } = generation;
@ -41,7 +48,7 @@ const selector = createMemoizedSelector(
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node.data.type]; const nodeTemplate = nodeTemplates.templates[node.data.type];
if (!nodeTemplate) { if (!nodeTemplate) {
// Node type not found // Node type not found

View File

@ -74,57 +74,60 @@ const AddNodePopover = () => {
(state) => state.nodes.connectionStartParams?.handleType (state) => state.nodes.connectionStartParams?.handleType
); );
const selector = createMemoizedSelector([stateSelector], ({ nodes }) => { const selector = createMemoizedSelector(
// If we have a connection in progress, we need to filter the node choices [stateSelector],
const filteredNodeTemplates = fieldFilter ({ nodeTemplates }) => {
? filter(nodes.nodeTemplates, (template) => { // If we have a connection in progress, we need to filter the node choices
const handles = const filteredNodeTemplates = fieldFilter
handleFilter == 'source' ? template.inputs : template.outputs; ? filter(nodeTemplates.templates, (template) => {
const handles =
handleFilter == 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => { return some(handles, (handle) => {
const sourceType = const sourceType =
handleFilter == 'source' ? fieldFilter : handle.type; handleFilter == 'source' ? fieldFilter : handle.type;
const targetType = const targetType =
handleFilter == 'target' ? fieldFilter : handle.type; handleFilter == 'target' ? fieldFilter : handle.type;
return validateSourceAndTargetTypes(sourceType, targetType); return validateSourceAndTargetTypes(sourceType, targetType);
}); });
}) })
: map(nodes.nodeTemplates); : map(nodeTemplates.templates);
const options: InvSelectOption[] = map( const options: InvSelectOption[] = map(
filteredNodeTemplates, filteredNodeTemplates,
(template) => { (template) => {
return { return {
label: template.title, label: template.title,
value: template.type, value: template.type,
description: template.description, description: template.description,
tags: template.tags, tags: template.tags,
}; };
}
);
//We only want these nodes if we're not filtered
if (fieldFilter === null) {
options.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
options.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
tags: ['notes'],
});
} }
);
//We only want these nodes if we're not filtered options.sort((a, b) => a.label.localeCompare(b.label));
if (fieldFilter === null) {
options.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
options.push({ return { options };
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
tags: ['notes'],
});
} }
);
options.sort((a, b) => a.label.localeCompare(b.label));
return { options };
});
const { options } = useAppSelector(selector); const { options } = useAppSelector(selector);
const isOpen = useAppSelector((state) => state.nodes.isAddNodePopoverOpen); const isOpen = useAppSelector((state) => state.nodes.isAddNodePopoverOpen);

View File

@ -14,8 +14,8 @@ const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const hasTemplateSelector = useMemo( const hasTemplateSelector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => createMemoizedSelector(stateSelector, ({ nodeTemplates }) =>
Boolean(nodes.nodeTemplates[type]) Boolean(nodeTemplates.templates[type])
), ),
[type] [type]
); );

View File

@ -14,28 +14,31 @@ import { useTranslation } from 'react-i18next';
import EditableNodeTitle from './details/EditableNodeTitle'; import EditableNodeTitle from './details/EditableNodeTitle';
const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { const selector = createMemoizedSelector(
const lastSelectedNodeId = stateSelector,
nodes.selectedNodes[nodes.selectedNodes.length - 1]; ({ nodes, nodeTemplates }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find( const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId (node) => node.id === lastSelectedNodeId
); );
const lastSelectedNodeTemplate = lastSelectedNode const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type] ? nodeTemplates.templates[lastSelectedNode.data.type]
: undefined; : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return; return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
} }
);
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
});
const InspectorDetailsTab = () => { const InspectorDetailsTab = () => {
const data = useAppSelector(selector); const data = useAppSelector(selector);

View File

@ -13,34 +13,37 @@ import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview'; import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { const selector = createMemoizedSelector(
const lastSelectedNodeId = stateSelector,
nodes.selectedNodes[nodes.selectedNodes.length - 1]; ({ nodes, nodeTemplates }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find( const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId (node) => node.id === lastSelectedNodeId
); );
const lastSelectedNodeTemplate = lastSelectedNode const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type] ? nodeTemplates.templates[lastSelectedNode.data.type]
: undefined; : undefined;
const nes = const nes =
nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
if ( if (
!isInvocationNode(lastSelectedNode) || !isInvocationNode(lastSelectedNode) ||
!nes || !nes ||
!lastSelectedNodeTemplate !lastSelectedNodeTemplate
) { ) {
return; return;
}
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
} }
);
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
});
const InspectorOutputsTab = () => { const InspectorOutputsTab = () => {
const data = useAppSelector(selector); const data = useAppSelector(selector);

View File

@ -6,22 +6,25 @@ import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataView
import { memo } from 'react'; import { memo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { const selector = createMemoizedSelector(
const lastSelectedNodeId = stateSelector,
nodes.selectedNodes[nodes.selectedNodes.length - 1]; ({ nodes, nodeTemplates }) => {
const lastSelectedNodeId =
nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find( const lastSelectedNode = nodes.nodes.find(
(node) => node.id === lastSelectedNodeId (node) => node.id === lastSelectedNodeId
); );
const lastSelectedNodeTemplate = lastSelectedNode const lastSelectedNodeTemplate = lastSelectedNode
? nodes.nodeTemplates[lastSelectedNode.data.type] ? nodeTemplates.templates[lastSelectedNode.data.type]
: undefined; : undefined;
return { return {
template: lastSelectedNodeTemplate, template: lastSelectedNodeTemplate,
}; };
}); }
);
const NodeTemplateInspector = () => { const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector); const { template } = useAppSelector(selector);

View File

@ -10,12 +10,12 @@ import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string) => { export const useAnyOrDirectInputFieldNames = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return []; return [];
} }
const nodeTemplate = nodes.nodeTemplates[node.data.type]; const nodeTemplate = nodeTemplates.templates[node.data.type];
if (!nodeTemplate) { if (!nodeTemplate) {
return []; return [];
} }

View File

@ -1,5 +1,3 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { RootState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { import {
DRAG_HANDLE_CLASSNAME, DRAG_HANDLE_CLASSNAME,
@ -16,17 +14,14 @@ import { useCallback } from 'react';
import type { Node } from 'reactflow'; import type { Node } from 'reactflow';
import { useReactFlow } from 'reactflow'; import { useReactFlow } from 'reactflow';
const templatesSelector = createMemoizedSelector(
[(state: RootState) => state.nodes],
(nodes) => nodes.nodeTemplates
);
export const SHARED_NODE_PROPERTIES: Partial<Node> = { export const SHARED_NODE_PROPERTIES: Partial<Node> = {
dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, dragHandle: `.${DRAG_HANDLE_CLASSNAME}`,
}; };
export const useBuildNode = () => { export const useBuildNode = () => {
const nodeTemplates = useAppSelector(templatesSelector); const nodeTemplates = useAppSelector(
(state) => state.nodeTemplates.templates
);
const flow = useReactFlow(); const flow = useReactFlow();

View File

@ -10,12 +10,12 @@ import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string) => { export const useConnectionInputFieldNames = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return []; return [];
} }
const nodeTemplate = nodes.nodeTemplates[node.data.type]; const nodeTemplate = nodeTemplates.templates[node.data.type];
if (!nodeTemplate) { if (!nodeTemplate) {
return []; return [];
} }

View File

@ -8,12 +8,12 @@ import { useMemo } from 'react';
export const useDoNodeVersionsMatch = (nodeId: string) => { export const useDoNodeVersionsMatch = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return false; return false;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
if (!nodeTemplate?.version || !node.data?.version) { if (!nodeTemplate?.version || !node.data?.version) {
return false; return false;
} }

View File

@ -7,12 +7,12 @@ import { useMemo } from 'react';
export const useFieldInputKind = (nodeId: string, fieldName: string) => { export const useFieldInputKind = (nodeId: string, fieldName: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
const fieldTemplate = nodeTemplate?.inputs[fieldName]; const fieldTemplate = nodeTemplate?.inputs[fieldName];
return fieldTemplate?.input; return fieldTemplate?.input;
}), }),

View File

@ -7,12 +7,12 @@ import { useMemo } from 'react';
export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { export const useFieldInputTemplate = (nodeId: string, fieldName: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate?.inputs[fieldName]; return nodeTemplate?.inputs[fieldName];
}), }),
[fieldName, nodeId] [fieldName, nodeId]

View File

@ -7,12 +7,12 @@ import { useMemo } from 'react';
export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate?.outputs[fieldName]; return nodeTemplate?.outputs[fieldName];
}), }),
[fieldName, nodeId] [fieldName, nodeId]

View File

@ -12,12 +12,12 @@ export const useFieldTemplate = (
) => { ) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName]; return nodeTemplate?.[KIND_MAP[kind]][fieldName];
}), }),
[fieldName, kind, nodeId] [fieldName, kind, nodeId]

View File

@ -12,12 +12,12 @@ export const useFieldTemplateTitle = (
) => { ) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return; return;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title;
}), }),
[fieldName, kind, nodeId] [fieldName, kind, nodeId]

View File

@ -4,19 +4,15 @@ import { useAppSelector } from 'app/store/storeHooks';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
const selector = createMemoizedSelector(stateSelector, (state) => { const selector = createMemoizedSelector(stateSelector, (state) =>
const nodes = state.nodes.nodes; state.nodes.nodes.filter(isInvocationNode).some((node) => {
const templates = state.nodes.nodeTemplates; const template = state.nodeTemplates.templates[node.data.type];
const needsUpdate = nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) { if (!template) {
return false; return false;
} }
return getNeedsUpdate(node, template); return getNeedsUpdate(node, template);
}); })
return needsUpdate; );
});
export const useGetNodesNeedUpdate = () => { export const useGetNodesNeedUpdate = () => {
const getNeedsUpdate = useAppSelector(selector); const getNeedsUpdate = useAppSelector(selector);

View File

@ -7,12 +7,12 @@ import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string) => { export const useNodeClassification = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return false; return false;
} }
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate?.classification; return nodeTemplate?.classification;
}), }),
[nodeId] [nodeId]

View File

@ -8,9 +8,9 @@ import { useMemo } from 'react';
export const useNodeNeedsUpdate = (nodeId: string) => { export const useNodeNeedsUpdate = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
const template = nodes.nodeTemplates[node?.data.type ?? '']; const template = nodeTemplates.templates[node?.data.type ?? ''];
if (isInvocationNode(node) && template) { if (isInvocationNode(node) && template) {
return getNeedsUpdate(node, template); return getNeedsUpdate(node, template);
} }

View File

@ -6,9 +6,9 @@ import { useMemo } from 'react';
export const useNodeTemplate = (nodeId: string) => { export const useNodeTemplate = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; const nodeTemplate = nodeTemplates.templates[node?.data.type ?? ''];
return nodeTemplate; return nodeTemplate;
}), }),
[nodeId] [nodeId]

View File

@ -9,9 +9,8 @@ export const useNodeTemplateByType = (type: string) => {
() => () =>
createMemoizedSelector( createMemoizedSelector(
stateSelector, stateSelector,
({ nodes }): InvocationTemplate | undefined => { ({ nodeTemplates }): InvocationTemplate | undefined => {
const nodeTemplate = nodes.nodeTemplates[type]; return nodeTemplates.templates[type];
return nodeTemplate;
} }
), ),
[type] [type]

View File

@ -7,13 +7,13 @@ import { useMemo } from 'react';
export const useNodeTemplateTitle = (nodeId: string) => { export const useNodeTemplateTitle = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return false; return false;
} }
const nodeTemplate = node const nodeTemplate = node
? nodes.nodeTemplates[node.data.type] ? nodeTemplates.templates[node.data.type]
: undefined; : undefined;
return nodeTemplate?.title; return nodeTemplate?.title;

View File

@ -9,12 +9,12 @@ import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => { export const useOutputFieldNames = (nodeId: string) => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(stateSelector, ({ nodes }) => { createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => {
const node = nodes.nodes.find((node) => node.id === nodeId); const node = nodes.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { if (!isInvocationNode(node)) {
return []; return [];
} }
const nodeTemplate = nodes.nodeTemplates[node.data.type]; const nodeTemplate = nodeTemplates.templates[node.data.type];
if (!nodeTemplate) { if (!nodeTemplate) {
return []; return [];
} }

View File

@ -0,0 +1,26 @@
import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice } from '@reduxjs/toolkit';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import type { NodeTemplatesState } from './types';
export const initialNodeTemplatesState: NodeTemplatesState = {
templates: {},
};
const nodesTemplatesSlice = createSlice({
name: 'nodeTemplates',
initialState: initialNodeTemplatesState,
reducers: {
nodeTemplatesBuilt: (
state,
action: PayloadAction<Record<string, InvocationTemplate>>
) => {
state.templates = action.payload;
},
},
});
export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions;
export default nodesTemplatesSlice.reducer;

View File

@ -4,7 +4,6 @@ import type { NodesState } from './types';
* Nodes slice persist denylist * Nodes slice persist denylist
*/ */
export const nodesPersistDenylist: (keyof NodesState)[] = [ export const nodesPersistDenylist: (keyof NodesState)[] = [
'nodeTemplates',
'connectionStartParams', 'connectionStartParams',
'connectionStartFieldType', 'connectionStartFieldType',
'selectedNodes', 'selectedNodes',

View File

@ -1,6 +1,7 @@
import type { PayloadAction } from '@reduxjs/toolkit'; import type { PayloadAction } from '@reduxjs/toolkit';
import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import { workflowLoaded } from 'features/nodes/store/actions'; import { workflowLoaded } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { import type {
BoardFieldValue, BoardFieldValue,
@ -41,7 +42,6 @@ import {
} from 'features/nodes/types/field'; } from 'features/nodes/types/field';
import type { import type {
AnyNode, AnyNode,
InvocationTemplate,
NodeExecutionState, NodeExecutionState,
} from 'features/nodes/types/invocation'; } from 'features/nodes/types/invocation';
import { import {
@ -97,7 +97,6 @@ const initialNodeExecutionState: Omit<NodeExecutionState, 'nodeId'> = {
export const initialNodesState: NodesState = { export const initialNodesState: NodesState = {
nodes: [], nodes: [],
edges: [], edges: [],
nodeTemplates: {},
isReady: false, isReady: false,
connectionStartParams: null, connectionStartParams: null,
connectionStartFieldType: null, connectionStartFieldType: null,
@ -656,13 +655,6 @@ const nodesSlice = createSlice({
shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => { shouldShowMinimapPanelChanged: (state, action: PayloadAction<boolean>) => {
state.shouldShowMinimapPanel = action.payload; state.shouldShowMinimapPanel = action.payload;
}, },
nodeTemplatesBuilt: (
state,
action: PayloadAction<Record<string, InvocationTemplate>>
) => {
state.nodeTemplates = action.payload;
state.isReady = true;
},
nodeEditorReset: (state) => { nodeEditorReset: (state) => {
state.nodes = []; state.nodes = [];
state.edges = []; state.edges = [];
@ -893,6 +885,9 @@ const nodesSlice = createSlice({
}); });
} }
}); });
builder.addCase(nodeTemplatesBuilt, (state) => {
state.isReady = true;
});
}, },
}); });
@ -935,7 +930,6 @@ export const {
nodeOpacityChanged, nodeOpacityChanged,
nodesChanged, nodesChanged,
nodesDeleted, nodesDeleted,
nodeTemplatesBuilt,
nodeUseCacheChanged, nodeUseCacheChanged,
notesNodeValueChanged, notesNodeValueChanged,
selectedAll, selectedAll,

View File

@ -16,7 +16,6 @@ import type {
export type NodesState = { export type NodesState = {
nodes: AnyNode[]; nodes: AnyNode[];
edges: InvocationNodeEdge[]; edges: InvocationNodeEdge[];
nodeTemplates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null; connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null; connectionStartFieldType: FieldType | null;
connectionMade: boolean; connectionMade: boolean;
@ -42,3 +41,7 @@ export type NodesState = {
export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & { export type WorkflowsState = Omit<WorkflowV2, 'nodes' | 'edges'> & {
isTouched: boolean; isTouched: boolean;
}; };
export type NodeTemplatesState = {
templates: Record<string, InvocationTemplate>;
};

View File

@ -33,7 +33,7 @@ const zWorkflowMetaVersion = z.object({
* - Workflow schema version bumped to 2.0.0 * - Workflow schema version bumped to 2.0.0
*/ */
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
const invocationTemplates = $store.get()?.getState().nodes.nodeTemplates; const invocationTemplates = $store.get()?.getState().nodeTemplates.templates;
if (!invocationTemplates) { if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized')); throw new Error(t('app.storeNotInitialized'));