feat(ui): move invocation templates out of redux

Templates are stored in nanostores. All hooks, selectors, etc are reworked to reference the nanostore.
This commit is contained in:
psychedelicious 2024-05-16 17:00:08 +10:00
parent f6a44681a8
commit 1d884fb794
23 changed files with 146 additions and 265 deletions

View File

@ -25,7 +25,7 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
unableToUpdateCount++; unableToUpdateCount++;
return; return;
} }
if (!getNeedsUpdate(node, template)) { if (!getNeedsUpdate(node.data, template)) {
// No need to increment the count here, since we're not actually updating // No need to increment the count here, since we're not actually updating
return; return;
} }

View File

@ -1,6 +1,8 @@
import { Badge, Flex } from '@invoke-ai/ui-library'; import { Badge, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens';
import { $templates } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow'; import type { EdgeProps } from 'reactflow';
import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow';
@ -22,9 +24,10 @@ const InvocationCollapsedEdge = ({
sourceHandleId, sourceHandleId,
targetHandleId, targetHandleId,
}: EdgeProps<{ count: number }>) => { }: EdgeProps<{ count: number }>) => {
const templates = useStore($templates);
const selector = useMemo( const selector = useMemo(
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected), () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[selected, source, sourceHandleId, target, targetHandleId] [templates, selected, source, sourceHandleId, target, targetHandleId]
); );
const { isSelected, shouldAnimate } = useAppSelector(selector); const { isSelected, shouldAnimate } = useAppSelector(selector);

View File

@ -1,5 +1,7 @@
import { Flex, Text } from '@invoke-ai/ui-library'; import { Flex, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { CSSProperties } from 'react'; import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react'; import { memo, useMemo } from 'react';
import type { EdgeProps } from 'reactflow'; import type { EdgeProps } from 'reactflow';
@ -21,9 +23,10 @@ const InvocationDefaultEdge = ({
sourceHandleId, sourceHandleId,
targetHandleId, targetHandleId,
}: EdgeProps) => { }: EdgeProps) => {
const templates = useStore($templates);
const selector = useMemo( const selector = useMemo(
() => makeEdgeSelector(source, sourceHandleId, target, targetHandleId, selected), () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected),
[source, sourceHandleId, target, targetHandleId, selected] [templates, source, sourceHandleId, target, targetHandleId, selected]
); );
const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector);

View File

@ -1,8 +1,8 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate, selectNodeTemplate } from 'features/nodes/store/selectors';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor'; import { getFieldColor } from './getEdgeColor';
@ -15,6 +15,7 @@ const defaultReturnValue = {
}; };
export const makeEdgeSelector = ( export const makeEdgeSelector = (
templates: Record<string, InvocationTemplate>,
source: string, source: string,
sourceHandleId: string | null | undefined, sourceHandleId: string | null | undefined,
target: string, target: string,
@ -35,13 +36,14 @@ export const makeEdgeSelector = (
return defaultReturnValue; return defaultReturnValue;
} }
const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); const sourceNodeTemplate = templates[sourceNode.data.type];
const targetNodeTemplate = templates[targetNode.data.type];
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const stroke = sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); const stroke =
sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
const sourceNodeTemplate = selectNodeTemplate(nodes, sourceNode.id);
const targetNodeTemplate = selectNodeTemplate(nodes, targetNode.id);
const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;

View File

@ -1,4 +1,5 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode'; import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
import { $templates } from 'features/nodes/store/nodesSlice'; import { $templates } from 'features/nodes/store/nodesSlice';
import type { InvocationNodeData } from 'features/nodes/types/invocation'; import type { InvocationNodeData } from 'features/nodes/types/invocation';
@ -12,6 +13,11 @@ const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { id: nodeId, type, isOpen, label } = data; const { id: nodeId, type, isOpen, label } = data;
const templates = useStore($templates); const templates = useStore($templates);
const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]); const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]);
const nodeExists = useAppSelector((s) => Boolean(s.nodes.present.nodes.find((n) => n.id === nodeId)));
if (!nodeExists) {
return null;
}
if (!hasTemplate) { if (!hasTemplate) {
return ( return (

View File

@ -1,31 +1,19 @@
import { EMPTY_ARRAY } from 'app/store/constants'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es'; import { keys, map } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const fieldNames = useMemo(() => {
createMemoizedSelector(selectNodesSlice, (nodes) => { const fields = map(template.inputs).filter(
const template = selectNodeTemplate(nodes, nodeId); (field) =>
if (!template) { (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
return EMPTY_ARRAY; keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
} );
const fields = map(template.inputs).filter( return getSortedFilteredFieldNames(fields);
(field) => }, [template]);
(['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) &&
keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
[nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames; return fieldNames;
}; };

View File

@ -1,34 +1,20 @@
import { EMPTY_ARRAY } from 'app/store/constants'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate';
import { keys, map } from 'lodash-es'; import { keys, map } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useConnectionInputFieldNames = (nodeId: string): string[] => { export const useConnectionInputFieldNames = (nodeId: string): string[] => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const fieldNames = useMemo(() => {
createMemoizedSelector(selectNodesSlice, (nodes) => { // get the visible fields
const template = selectNodeTemplate(nodes, nodeId); const fields = map(template.inputs).filter(
if (!template) { (field) =>
return EMPTY_ARRAY; (field.input === 'connection' && !field.type.isCollectionOrScalar) ||
} !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
// get the visible fields return getSortedFilteredFieldNames(fields);
const fields = map(template.inputs).filter( }, [template]);
(field) =>
(field.input === 'connection' && !field.type.isCollectionOrScalar) ||
!keys(TEMPLATE_BUILDER_MAP).includes(field.type.name)
);
return getSortedFilteredFieldNames(fields);
}),
[nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames; return fieldNames;
}; };

View File

@ -1,20 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate } from 'features/nodes/store/selectors';
import type { FieldInputTemplate } from 'features/nodes/types/field'; import type { FieldInputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const fieldTemplate = useMemo(() => template.inputs[fieldName] ?? null, [fieldName, template.inputs]);
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate; return fieldTemplate;
}; };

View File

@ -1,20 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import type { FieldOutputTemplate } from 'features/nodes/types/field'; import type { FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => { export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const fieldTemplate = useMemo(() => template.outputs[fieldName] ?? null, [fieldName, template.outputs]);
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, nodeId]
);
const fieldTemplate = useAppSelector(selector);
return fieldTemplate; return fieldTemplate;
}; };

View File

@ -1,27 +1,36 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useFieldTemplate = ( export const useFieldTemplate = (
nodeId: string, nodeId: string,
fieldName: string, fieldName: string,
kind: 'inputs' | 'outputs' kind: 'inputs' | 'outputs'
): FieldInputTemplate | FieldOutputTemplate | null => { ): FieldInputTemplate | FieldOutputTemplate => {
const selector = useMemo( const templates = useStore($templates);
() => const selectNodeType = useMemo(
createMemoizedSelector(selectNodesSlice, (nodes) => { () => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
if (kind === 'inputs') { [nodeId]
return selectFieldInputTemplate(nodes, nodeId, fieldName);
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName);
}),
[fieldName, kind, nodeId]
); );
const nodeType = useAppSelector(selectNodeType);
const fieldTemplate = useAppSelector(selector); const fieldTemplate = useMemo(() => {
const template = templates[nodeType];
assert(template, `Template for node type ${nodeType} not found`);
if (kind === 'inputs') {
const fieldTemplate = template.inputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
} else {
const fieldTemplate = template.outputs[fieldName];
assert(fieldTemplate, `Field template for field ${fieldName} not found`);
return fieldTemplate;
}
}, [fieldName, kind, nodeType, templates]);
return fieldTemplate; return fieldTemplate;
}; };

View File

@ -1,22 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'; import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => { export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => {
const selector = useMemo( const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
() => const fieldTemplateTitle = useMemo(() => fieldTemplate.title, [fieldTemplate]);
createSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldTemplateTitle = useAppSelector(selector);
return fieldTemplateTitle; return fieldTemplateTitle;
}; };

View File

@ -1,23 +1,9 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors';
import type { FieldType } from 'features/nodes/types/field'; import type { FieldType } from 'features/nodes/types/field';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => { export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => {
const selector = useMemo( const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind);
() => const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]);
createMemoizedSelector(selectNodesSlice, (nodes) => {
if (kind === 'inputs') {
return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}
return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null;
}),
[fieldName, kind, nodeId]
);
const fieldType = useAppSelector(selector);
return fieldType; return fieldType;
}; };

View File

@ -16,7 +16,7 @@ export const useGetNodesNeedUpdate = () => {
if (!template) { if (!template) {
return false; return false;
} }
return getNeedsUpdate(node, template); return getNeedsUpdate(node.data, template);
}) })
), ),
[templates] [templates]

View File

@ -1,26 +1,20 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { some } from 'lodash-es'; import { some } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useHasImageOutput = (nodeId: string): boolean => { export const useHasImageOutput = (nodeId: string): boolean => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
const hasImageOutput = useMemo(
() => () =>
createMemoizedSelector(selectNodesSlice, (nodes) => { some(
const template = selectNodeTemplate(nodes, nodeId); template?.outputs,
return some( (output) =>
template?.outputs, output.type.name === 'ImageField' &&
(output) => // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes
output.type.name === 'ImageField' && template?.type !== 'image'
// the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes ),
template?.type !== 'image' [template]
);
}),
[nodeId]
); );
const hasImageOutput = useAppSelector(selector);
return hasImageOutput; return hasImageOutput;
}; };

View File

@ -1,19 +1,9 @@
import { createSelector } from '@reduxjs/toolkit'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import type { Classification } from 'features/nodes/types/common'; import type { Classification } from 'features/nodes/types/common';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useNodeClassification = (nodeId: string): Classification | null => { export const useNodeClassification = (nodeId: string): Classification => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const classification = useMemo(() => template.classification, [template]);
createSelector(selectNodesSlice, (nodes) => { return classification;
return selectNodeTemplate(nodes, nodeId)?.classification ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title;
}; };

View File

@ -5,7 +5,7 @@ import { selectNodeData } from 'features/nodes/store/selectors';
import type { InvocationNodeData } from 'features/nodes/types/invocation'; import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useNodeData = (nodeId: string): InvocationNodeData | null => { export const useNodeData = (nodeId: string): InvocationNodeData => {
const selector = useMemo( const selector = useMemo(
() => () =>
createMemoizedSelector(selectNodesSlice, (nodes) => { createMemoizedSelector(selectNodesSlice, (nodes) => {

View File

@ -1,25 +1,11 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useAppSelector } from 'app/store/storeHooks'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useNodeNeedsUpdate = (nodeId: string) => { export const useNodeNeedsUpdate = (nodeId: string) => {
const selector = useMemo( const data = useNodeData(nodeId);
() => const template = useNodeTemplate(nodeId);
createMemoizedSelector(selectNodesSlice, (nodes) => { const needsUpdate = useMemo(() => getNeedsUpdate(data, template), [data, template]);
const node = selectInvocationNode(nodes, nodeId);
const template = selectNodeTemplate(nodes, nodeId);
if (!node || !template) {
return false;
}
return getNeedsUpdate(node, template);
}),
[nodeId]
);
const needsUpdate = useAppSelector(selector);
return needsUpdate; return needsUpdate;
}; };

View File

@ -1,20 +1,23 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit'; import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { selectInvocationNodeType } from 'features/nodes/store/selectors';
import type { InvocationTemplate } from 'features/nodes/types/invocation'; import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { useMemo } from 'react'; import { useMemo } from 'react';
import { assert } from 'tsafe';
export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => { export const useNodeTemplate = (nodeId: string): InvocationTemplate => {
const selector = useMemo( const templates = useStore($templates);
() => const selectNodeType = useMemo(
createSelector(selectNodesSlice, (nodes) => { () => createSelector(selectNodesSlice, (nodes) => selectInvocationNodeType(nodes, nodeId)),
return selectNodeTemplate(nodes, nodeId);
}),
[nodeId] [nodeId]
); );
const nodeType = useAppSelector(selectNodeType);
const nodeTemplate = useAppSelector(selector); const template = useMemo(() => {
const t = templates[nodeType];
return nodeTemplate; assert(t, `Template for node type ${nodeType} not found`);
return t;
}, [nodeType, templates]);
return template;
}; };

View File

@ -1,18 +1,8 @@
import { createSelector } from '@reduxjs/toolkit'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useNodeTemplateTitle = (nodeId: string): string | null => { export const useNodeTemplateTitle = (nodeId: string): string | null => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const title = useMemo(() => template.title, [template.title]);
createSelector(selectNodesSlice, (nodes) => {
return selectNodeTemplate(nodes, nodeId)?.title ?? null;
}),
[nodeId]
);
const title = useAppSelector(selector);
return title; return title;
}; };

View File

@ -1,26 +1,10 @@
import { EMPTY_ARRAY } from 'app/store/constants'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectNodeTemplate } from 'features/nodes/store/selectors';
import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames';
import { map } from 'lodash-es'; import { map } from 'lodash-es';
import { useMemo } from 'react'; import { useMemo } from 'react';
export const useOutputFieldNames = (nodeId: string) => { export const useOutputFieldNames = (nodeId: string): string[] => {
const selector = useMemo( const template = useNodeTemplate(nodeId);
() => const fieldNames = useMemo(() => getSortedFilteredFieldNames(map(template.outputs)), [template.outputs]);
createMemoizedSelector(selectNodesSlice, (nodes) => {
const template = selectNodeTemplate(nodes, nodeId);
if (!template) {
return EMPTY_ARRAY;
}
return getSortedFilteredFieldNames(map(template.outputs));
}),
[nodeId]
);
const fieldNames = useAppSelector(selector);
return fieldNames; return fieldNames;
}; };

View File

@ -1,18 +1,23 @@
import type { NodesState } from 'features/nodes/store/types'; import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation'; import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { assert } from 'tsafe';
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId); const node = nodesSlice.nodes.find((node) => node.id === nodeId);
if (!isInvocationNode(node)) { assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
return null;
}
return node; return node;
}; };
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => { export const selectInvocationNodeType = (nodesSlice: NodesState, nodeId: string): string => {
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null; const node = selectInvocationNode(nodesSlice, nodeId);
return node.data.type;
};
export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData => {
const node = selectInvocationNode(nodesSlice, nodeId);
return node.data;
}; };
export const selectFieldInputInstance = ( export const selectFieldInputInstance = (
@ -23,21 +28,3 @@ export const selectFieldInputInstance = (
const data = selectNodeData(nodesSlice, nodeId); const data = selectNodeData(nodesSlice, nodeId);
return data?.inputs[fieldName] ?? null; return data?.inputs[fieldName] ?? null;
}; };
export const selectFieldInputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldInputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.inputs[fieldName] ?? null;
};
export const selectFieldOutputTemplate = (
nodesSlice: NodesState,
nodeId: string,
fieldName: string
): FieldOutputTemplate | null => {
const template = selectNodeTemplate(nodesSlice, nodeId);
return template?.outputs[fieldName] ?? null;
};

View File

@ -1,17 +1,17 @@
import { deepClone } from 'common/util/deepClone'; import { deepClone } from 'common/util/deepClone';
import { satisfies } from 'compare-versions'; import { satisfies } from 'compare-versions';
import { NodeUpdateError } from 'features/nodes/types/error'; import { NodeUpdateError } from 'features/nodes/types/error';
import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
import { zParsedSemver } from 'features/nodes/types/semver'; import { zParsedSemver } from 'features/nodes/types/semver';
import { defaultsDeep, keys, pick } from 'lodash-es'; import { defaultsDeep, keys, pick } from 'lodash-es';
import { buildInvocationNode } from './buildInvocationNode'; import { buildInvocationNode } from './buildInvocationNode';
export const getNeedsUpdate = (node: InvocationNode, template: InvocationTemplate): boolean => { export const getNeedsUpdate = (data: InvocationNodeData, template: InvocationTemplate): boolean => {
if (node.data.type !== template.type) { if (data.type !== template.type) {
return true; return true;
} }
return node.data.version !== template.version; return data.version !== template.version;
}; };
/** /**
@ -20,7 +20,7 @@ export const getNeedsUpdate = (node: InvocationNode, template: InvocationTemplat
* @param template The invocation template to check against. * @param template The invocation template to check against.
*/ */
const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): boolean => { const getMayUpdateNode = (node: InvocationNode, template: InvocationTemplate): boolean => {
const needsUpdate = getNeedsUpdate(node, template); const needsUpdate = getNeedsUpdate(node.data, template);
if (!needsUpdate || node.data.type !== template.type) { if (!needsUpdate || node.data.type !== template.type) {
return false; return false;
} }

View File

@ -68,7 +68,7 @@ export const validateWorkflow = (
return; return;
} }
if (getNeedsUpdate(node, template)) { if (getNeedsUpdate(node.data, template)) {
// This node needs to be updated, based on comparison of its version to the template version // This node needs to be updated, based on comparison of its version to the template version
const message = t('nodes.mismatchedVersion', { const message = t('nodes.mismatchedVersion', {
node: node.id, node: node.id,