feat(ui): move invocation templates out of redux (wip)

This commit is contained in:
psychedelicious 2024-05-16 15:17:23 +10:00
parent d4df312300
commit f6a44681a8
18 changed files with 303 additions and 318 deletions

View File

@ -1,7 +1,6 @@
import type { UnknownAction } from '@reduxjs/toolkit';
import { deepClone } from 'common/util/deepClone';
import { isAnyGraphBuilt } from 'features/nodes/store/actions';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { appInfoApi } from 'services/api/endpoints/appInfo';
import type { Graph } from 'services/api/types';
import { socketGeneratorProgress } from 'services/events/actions';
@ -25,13 +24,6 @@ export const actionSanitizer = <A extends UnknownAction>(action: A): A => {
};
}
if (nodeTemplatesBuilt.match(action)) {
return {
...action,
payload: '<Node templates omitted>',
};
}
if (socketGeneratorProgress.match(action)) {
const sanitized = deepClone(action);
if (sanitized.payload.data.progress_image) {

View File

@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import { parseSchema } from 'features/nodes/util/schema/parseSchema';
import { size } from 'lodash-es';
import { appInfoApi } from 'services/api/endpoints/appInfo';
@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo';
export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => {
startAppListening({
matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled,
effect: (action, { dispatch, getState }) => {
effect: (action, { getState }) => {
const log = logger('system');
const schemaJSON = action.payload;
@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening
log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`);
dispatch(nodeTemplatesBuilt(nodeTemplates));
$templates.set(nodeTemplates);
},
});

View File

@ -1,7 +1,7 @@
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { updateAllNodesRequested } from 'features/nodes/store/actions';
import { nodeReplaced } from 'features/nodes/store/nodesSlice';
import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice';
import { NodeUpdateError } from 'features/nodes/types/error';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate';
@ -14,7 +14,8 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi
actionCreator: updateAllNodesRequested,
effect: (action, { dispatch, getState }) => {
const log = logger('nodes');
const { nodes, templates } = getState().nodes.present;
const { nodes } = getState().nodes.present;
const templates = $templates.get();
let unableToUpdateCount = 0;

View File

@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions';
import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow';
@ -14,10 +15,10 @@ import { fromZodError } from 'zod-validation-error';
export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: workflowLoadRequested,
effect: (action, { dispatch, getState }) => {
effect: (action, { dispatch }) => {
const log = logger('nodes');
const { workflow, asCopy } = action.payload;
const nodeTemplates = getState().nodes.present.templates;
const nodeTemplates = $templates.get();
try {
const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates);

View File

@ -1,3 +1,4 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
@ -9,14 +10,16 @@ import { selectControlLayersSlice } from 'features/controlLayers/store/controlLa
import type { Layer } from 'features/controlLayers/store/types';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { selectSystemSlice } from 'features/system/store/systemSlice';
import { activeTabNameSelector } from 'features/ui/store/uiSelectors';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { useMemo } from 'react';
import { getConnectedEdges } from 'reactflow';
const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
@ -26,200 +29,205 @@ const LAYER_TYPE_TO_TKEY: Record<Layer['type'], string> = {
regional_guidance_layer: 'controlLayers.regionalGuidance',
};
const selector = createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectWorkflowSettingsSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const createSelector = (templates: Record<string, InvocationTemplate>) =>
createMemoizedSelector(
[
selectControlAdaptersSlice,
selectGenerationSlice,
selectSystemSlice,
selectNodesSlice,
selectWorkflowSettingsSlice,
selectDynamicPromptsSlice,
selectControlLayersSlice,
activeTabNameSelector,
],
(controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => {
const { model } = generation;
const { size } = controlLayers.present;
const { positivePrompt } = controlLayers.present;
const { isConnected } = system;
const { isConnected } = system;
const reasons: { prefix?: string; content: string }[] = [];
const reasons: { prefix?: string; content: string }[] = [];
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
// Cannot generate if not connected
if (!isConnected) {
reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') });
}
if (activeTabName === 'workflows') {
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
if (activeTabName === 'workflows') {
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
const nodeTemplate = templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
});
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
nodes.nodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
const nodeTemplate = nodes.templates[node.data.type];
if (!nodeTemplate) {
// Node type not found
reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') });
return;
}
const connectedEdges = getConnectedEdges([node], nodes.edges);
forEach(node.data.inputs, (field) => {
const fieldTemplate = nodeTemplate.inputs[field.name];
const hasConnection = connectedEdges.some(
(edge) => edge.target === node.id && edge.targetHandle === field.name
);
if (!fieldTemplate) {
reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') });
return;
}
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({
content: i18n.t('parameters.invoke.missingInputForField', {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
}),
});
return;
}
});
});
}
} else {
if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) {
reasons.push({ content: i18n.t('parameters.invoke.noPrompts') });
}
if (!model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') });
}
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64
if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
if (l.type === 'ip_adapter_layer') {
// Must have model
if (!l.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (l.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!l.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
if (activeTabName === 'generation') {
// Handling for generation tab
controlLayers.present.layers
.filter((l) => l.isEnabled)
.forEach((l, i) => {
const layerLiteral = i18n.t('controlLayers.layers_one');
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
const problems: string[] = [];
if (l.type === 'control_adapter_layer') {
// Must have model
if (!ipAdapter.model) {
if (!l.controlAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
}
// Model base must match
if (l.controlAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
}
// Must have a control image OR, if it has a processor, it must have a processed image
if (!l.controlAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected'));
} else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) {
problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed'));
}
// T2I Adapters require images have dimensions that are multiples of 64
if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) {
problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions'));
}
}
if (l.type === 'ip_adapter_layer') {
// Must have model
if (!l.ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
if (l.ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
if (!l.ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
}
}
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (l.type === 'initial_image_layer') {
// Must have an image
if (!l.image) {
problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected'));
}
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (l.type === 'regional_guidance_layer') {
// Must have a region
if (l.maskObjects.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
}
// Must have at least 1 prompt or IP Adapter
if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) {
problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
}
l.ipAdapters.forEach((ipAdapter) => {
// Must have model
if (!ipAdapter.model) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
}
// Model base must match
if (ipAdapter.model?.base !== model?.base) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
}
// Must have an image
if (!ipAdapter.image) {
problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
}
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) });
}
});
if (problems.length) {
const content = upperFirst(problems.join(', '));
reasons.push({ prefix, content });
}
});
} else {
// Handling for all other tabs
selectControlAdapterAll(controlAdapters)
.filter((ca) => ca.isEnabled)
.forEach((ca, i) => {
if (!ca.isEnabled) {
return;
}
if (!ca.model) {
reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) });
} else if (ca.model.base !== model?.base) {
// This should never happen, just a sanity check
reasons.push({
content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }),
});
}
if (
!ca.controlImage ||
(isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none')
) {
reasons.push({
content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }),
});
}
});
}
}
}
return { isReady: !reasons.length, reasons };
}
);
return { isReady: !reasons.length, reasons };
}
);
export const useIsReadyToEnqueue = () => {
const templates = useStore($templates);
const selector = useMemo(() => createSelector(templates), [templates]);
const value = useAppSelector(selector);
return value;
};

View File

@ -2,21 +2,16 @@ import 'reactflow/dist/style.css';
import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import {
addNodePopoverClosed,
addNodePopoverOpened,
nodeAdded,
selectNodesSlice,
} from 'features/nodes/store/nodesSlice';
import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useRef } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { flushSync } from 'react-dom';
import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
@ -54,14 +49,15 @@ const AddNodePopover = () => {
const { t } = useTranslation();
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null);
const templates = useStore($templates);
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const options = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const filteredNodeTemplates = fieldFilter
? filter(nodes.templates, (template) => {
? filter(templates, (template) => {
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => {
@ -71,7 +67,7 @@ const AddNodePopover = () => {
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(nodes.templates);
: map(templates);
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
return {
@ -101,10 +97,9 @@ const AddNodePopover = () => {
options.sort((a, b) => a.label.localeCompare(b.label));
return { options };
});
return options;
}, [fieldFilter, handleFilter, t, templates]);
const { options } = useAppSelector(selector);
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
const addNode = useCallback(

View File

@ -1,7 +1,6 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates } from 'features/nodes/store/nodesSlice';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import type { NodeProps } from 'reactflow';
@ -11,13 +10,8 @@ import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback';
const InvocationNodeWrapper = (props: NodeProps<InvocationNodeData>) => {
const { data, selected } = props;
const { id: nodeId, type, isOpen, label } = data;
const hasTemplateSelector = useMemo(
() => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])),
[type]
);
const hasTemplate = useAppSelector(hasTemplateSelector);
const templates = useStore($templates);
const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]);
if (!hasTemplate) {
return (

View File

@ -1,36 +1,39 @@
import { Box, Flex, FormControl, FormLabel, HStack, Text } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import EditableNodeTitle from './details/EditableNodeTitle';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
});
const InspectorDetailsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) {
return;
}
return {
nodeId: lastSelectedNode.data.id,
nodeVersion: lastSelectedNode.data.version,
templateTitle: lastSelectedNodeTemplate.title,
};
}),
[templates]
);
const data = useAppSelector(selector);
const { t } = useTranslation();

View File

@ -1,38 +1,41 @@
import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import type { ImageOutput } from 'services/api/types';
import type { AnyResult } from 'services/events/types';
import ImageOutputPreview from './outputs/ImageOutputPreview';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
return;
}
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
});
const InspectorOutputsTab = () => {
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__'];
if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) {
return;
}
return {
outputs: nes.outputs,
outputType: lastSelectedNodeTemplate.outputType,
};
}),
[templates]
);
const data = useAppSelector(selector);
const { t } = useTranslation();

View File

@ -1,25 +1,26 @@
import { useStore } from '@nanostores/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { IAINoContentFallback } from 'common/components/IAIImageFallback';
import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { memo } from 'react';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
const selector = createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined;
return {
template: lastSelectedNodeTemplate,
};
});
const NodeTemplateInspector = () => {
const { template } = useAppSelector(selector);
const templates = useStore($templates);
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1];
const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId);
const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined;
return lastSelectedNodeTemplate;
}),
[templates]
);
const template = useAppSelector(selector);
const { t } = useTranslation();
if (!template) {

View File

@ -1,4 +1,5 @@
import { useAppSelector } from 'app/store/storeHooks';
import { useStore } from '@nanostores/react';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
@ -8,8 +9,7 @@ import { useCallback } from 'react';
import { useReactFlow } from 'reactflow';
export const useBuildNode = () => {
const nodeTemplates = useAppSelector((s) => s.nodes.present.templates);
const templates = useStore($templates);
const flow = useReactFlow();
return useCallback(
@ -41,10 +41,10 @@ export const useBuildNode = () => {
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = nodeTemplates[type] as InvocationTemplate;
const template = templates[type] as InvocationTemplate;
return buildInvocationNode(position, template);
},
[nodeTemplates, flow]
[templates, flow]
);
};

View File

@ -1,20 +1,26 @@
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate';
const selector = createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = nodes.templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
})
);
import { useMemo } from 'react';
export const useGetNodesNeedUpdate = () => {
const getNeedsUpdate = useAppSelector(selector);
return getNeedsUpdate;
const templates = useStore($templates);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodes) =>
nodes.nodes.filter(isInvocationNode).some((node) => {
const template = templates[node.data.type];
if (!template) {
return false;
}
return getNeedsUpdate(node, template);
})
),
[templates]
);
const needsUpdate = useAppSelector(selector);
return needsUpdate;
};

View File

@ -1,5 +1,7 @@
// TODO: enable this at some point
import { useStore } from '@nanostores/react';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
@ -13,6 +15,7 @@ import type { Connection, Node } from 'reactflow';
export const useIsValidConnection = () => {
const store = useAppStore();
const templates = useStore($templates);
const shouldValidateGraph = useAppSelector((s) => s.workflowSettings.shouldValidateGraph);
const isValidConnection = useCallback(
({ source, sourceHandle, target, targetHandle }: Connection): boolean => {
@ -27,7 +30,7 @@ export const useIsValidConnection = () => {
}
const state = store.getState();
const { nodes, edges, templates } = state.nodes.present;
const { nodes, edges } = state.nodes.present;
// Find the source and target nodes
const sourceNode = nodes.find((node) => node.id === source) as Node<InvocationNodeData>;
@ -76,7 +79,7 @@ export const useIsValidConnection = () => {
// Graphs much be acyclic (no loops!)
return getIsGraphAcyclic(source, target, nodes, edges);
},
[shouldValidateGraph, store]
[shouldValidateGraph, templates, store]
);
return isValidConnection;

View File

@ -602,9 +602,6 @@ export const nodesSlice = createSlice({
state.connectionStartParams = null;
state.connectionStartFieldType = null;
},
nodeTemplatesBuilt: (state, action: PayloadAction<Record<string, InvocationTemplate>>) => {
state.templates = action.payload;
},
undo: (state) => state,
redo: (state) => state,
},
@ -728,7 +725,6 @@ export const {
selectionPasted,
viewportChanged,
edgeAdded,
nodeTemplatesBuilt,
undo,
redo,
} = nodesSlice.actions;
@ -770,6 +766,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
);
export const $cursorPos = atom<XYPosition | null>(null);
export const $templates = atom<Record<string, InvocationTemplate>>({});
export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);

View File

@ -1,6 +1,6 @@
import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => {
@ -15,14 +15,6 @@ export const selectNodeData = (nodesSlice: NodesState, nodeId: string): Invocati
return selectInvocationNode(nodesSlice, nodeId)?.data ?? null;
};
export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => {
const node = selectInvocationNode(nodesSlice, nodeId);
if (!node) {
return null;
}
return nodesSlice.templates[node.data.type] ?? null;
};
export const selectFieldInputInstance = (
nodesSlice: NodesState,
nodeId: string,

View File

@ -2,7 +2,6 @@ import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/no
import type {
AnyNode,
InvocationNodeEdge,
InvocationTemplate,
NodeExecutionState,
} from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
@ -12,7 +11,6 @@ export type NodesState = {
_version: 1;
nodes: AnyNode[];
edges: InvocationNodeEdge[];
templates: Record<string, InvocationTemplate>;
connectionStartParams: OnConnectStartParams | null;
connectionStartFieldType: FieldType | null;
connectionMade: boolean;

View File

@ -1,11 +1,10 @@
import * as dagre from '@dagrejs/dagre';
import { logger } from 'app/logging/logger';
import { getStore } from 'app/store/nanostores/store';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance';
import { t } from 'i18next';
import { forEach } from 'lodash-es';
import type { NonNullableGraph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@ -18,11 +17,7 @@ import { v4 as uuidv4 } from 'uuid';
* @returns The workflow.
*/
export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => {
const invocationTemplates = getStore().getState().nodes.present.templates;
if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized'));
}
const templates = $templates.get();
// Initialize the workflow
const workflow: WorkflowV3 = {
@ -44,11 +39,11 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor
// Convert nodes
forEach(graph.nodes, (node) => {
const template = invocationTemplates[node.type];
const template = templates[node.type];
// Skip missing node templates - this is a best-effort
if (!template) {
logger('nodes').warn(`Node type ${node.type} not found in invocationTemplates`);
logger('nodes').warn(`Node type ${node.type} not found in templates`);
return;
}

View File

@ -1,5 +1,5 @@
import { $store } from 'app/store/nanostores/store';
import { deepClone } from 'common/util/deepClone';
import { $templates } from 'features/nodes/store/nodesSlice';
import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error';
import type { FieldType } from 'features/nodes/types/field';
import type { InvocationNodeData } from 'features/nodes/types/invocation';
@ -33,11 +33,7 @@ const zWorkflowMetaVersion = z.object({
* - Workflow schema version bumped to 2.0.0
*/
const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
const invocationTemplates = $store.get()?.getState().nodes.present.templates;
if (!invocationTemplates) {
throw new Error(t('app.storeNotInitialized'));
}
const templates = $templates.get();
workflowToMigrate.nodes.forEach((node) => {
if (node.type === 'invocation') {
@ -57,7 +53,7 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
(output.type as unknown as FieldType) = newFieldType;
});
// Add node pack
const invocationTemplate = invocationTemplates[node.data.type];
const invocationTemplate = templates[node.data.type];
const nodePack = invocationTemplate ? invocationTemplate.nodePack : t('common.unknown');
(node.data as unknown as InvocationNodeData).nodePack = nodePack;