From f5447cdc23f86a9642f47e1e98f5ac35056c1943 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 13 Feb 2024 16:30:00 +1100 Subject: [PATCH] feat(ui): workflow schema v3 (WIP) The changes aim to deduplicate data between workflows and node templates, decoupling workflows from internal implementation details. A good amount of data that was needlessly duplicated from the node template to the workflow is removed. These changes substantially reduce the file size of workflows (and therefore the images with embedded workflows): - Default T2I SD1.5 workflow JSON is reduced from 23.7kb (798 lines) to 10.9kb (407 lines). - Default tiled upscale workflow JSON is reduced from 102.7kb (3341 lines) to 51.9kb (1774 lines). The trade-off is that we need to reference node templates to get things like the field type and other things. In practice, this is a non-issue, because we need a node template to do anything with a node anyways. - Field types are not included in the workflow. They are always pulled from the node templates. The field type is now properly an internal implementation detail and we can change it as needed. Previously this would require a migration for the workflow itself. With the v3 schema, the structure of a field type is an internal implementation detail that we are free to change as we see fit. - Workflow nodes no long have an `outputs` property and there is no longer such a thing as a `FieldOutputInstance`. These are only on the templates. These were never referenced at a time when we didn't also have the templates available, and there'd be no reason to do so. - Node width and height are no longer stored in the node. These weren't used. Also, per https://reactflow.dev/api-reference/types/node, we shouldn't be programmatically changing these properties. A future enhancement can properly add node resizing. - `nodeTemplates` slice is merged back into `nodesSlice` as `nodes.templates`. Turns out it's just a hassle having these separate in separate slices. - Workflow migration logic updated to support the new schema. V1 workflows migrate all the way to v3 now. - Changes throughout the nodes code to accommodate the above changes. --- .../middleware/devtools/actionSanitizer.ts | 2 +- .../listeners/getOpenAPISchema.ts | 2 +- .../listeners/updateAllNodesRequested.ts | 3 +- .../listeners/workflowLoadRequested.ts | 2 +- invokeai/frontend/web/src/app/store/store.ts | 2 - .../frontend/web/src/app/store/storeHooks.ts | 3 +- invokeai/frontend/web/src/app/store/util.ts | 2 + .../src/common/hooks/useIsReadyToEnqueue.ts | 6 +- .../flow/AddNodePopover/AddNodePopover.tsx | 14 +- .../flow/edges/util/makeEdgeSelector.ts | 18 +- .../InvocationNodeCollapsedHandles.tsx | 19 +- .../Invocation/InvocationNodeWrapper.tsx | 4 +- .../Invocation/fields/EditableFieldTitle.tsx | 4 +- .../nodes/Invocation/fields/FieldTitle.tsx | 2 +- .../Invocation/fields/FieldTooltipContent.tsx | 6 +- .../nodes/Invocation/fields/InputField.tsx | 6 +- .../Invocation/fields/InputFieldRenderer.tsx | 29 +- .../Invocation/fields/LinearViewField.tsx | 4 +- .../nodes/Invocation/fields/OutputField.tsx | 10 +- .../inspector/InspectorDetailsTab.tsx | 5 +- .../inspector/InspectorOutputsTab.tsx | 5 +- .../inspector/InspectorTemplateTab.tsx | 5 +- .../hooks/useAnyOrDirectInputFieldNames.ts | 20 +- .../src/features/nodes/hooks/useBuildNode.ts | 2 +- .../hooks/useConnectionInputFieldNames.ts | 20 +- .../nodes/hooks/useConnectionState.ts | 10 +- .../nodes/hooks/useDoNodeVersionsMatch.ts | 18 +- .../nodes/hooks/useDoesInputHaveValue.ts | 12 +- .../src/features/nodes/hooks/useFieldData.ts | 23 - .../nodes/hooks/useFieldInputInstance.ts | 15 +- .../features/nodes/hooks/useFieldInputKind.ts | 15 +- .../nodes/hooks/useFieldInputTemplate.ts | 15 +- .../src/features/nodes/hooks/useFieldLabel.ts | 10 +- .../nodes/hooks/useFieldOutputInstance.ts | 23 - .../nodes/hooks/useFieldOutputTemplate.ts | 15 +- .../features/nodes/hooks/useFieldTemplate.ts | 21 +- .../nodes/hooks/useFieldTemplateTitle.ts | 16 +- .../features/nodes/hooks/useFieldType.ts.ts | 14 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 5 +- .../features/nodes/hooks/useHasImageOutput.ts | 13 +- .../features/nodes/hooks/useIsIntermediate.ts | 10 +- .../nodes/hooks/useIsValidConnection.ts | 44 +- .../nodes/hooks/useNodeClassification.ts | 17 +- .../src/features/nodes/hooks/useNodeData.ts | 7 +- .../src/features/nodes/hooks/useNodeLabel.ts | 9 +- .../nodes/hooks/useNodeNeedsUpdate.ts | 15 +- .../src/features/nodes/hooks/useNodePack.ts | 10 +- .../features/nodes/hooks/useNodeTemplate.ts | 13 +- .../nodes/hooks/useNodeTemplateByType.ts | 10 +- .../nodes/hooks/useNodeTemplateTitle.ts | 15 +- .../nodes/hooks/useOutputFieldNames.ts | 20 +- .../src/features/nodes/hooks/useUseCache.ts | 8 +- .../nodes/hooks/useWorkflowWatcher.ts | 4 +- .../web/src/features/nodes/store/actions.ts | 4 +- .../nodes/store/nodeTemplatesSlice.ts | 24 - .../src/features/nodes/store/nodesSlice.ts | 15 +- .../web/src/features/nodes/store/selectors.ts | 51 + .../web/src/features/nodes/store/types.ts | 5 +- .../store/util/findConnectionToValidHandle.ts | 30 +- .../util/makeIsConnectionValidSelector.ts | 2 +- .../src/features/nodes/store/workflowSlice.ts | 6 +- .../web/src/features/nodes/types/field.ts | 130 +-- .../src/features/nodes/types/invocation.ts | 24 +- .../web/src/features/nodes/types/v2/common.ts | 188 ++++ .../src/features/nodes/types/v2/constants.ts | 80 ++ .../web/src/features/nodes/types/v2/error.ts | 58 ++ .../web/src/features/nodes/types/v2/field.ts | 875 ++++++++++++++++++ .../src/features/nodes/types/v2/invocation.ts | 93 ++ .../src/features/nodes/types/v2/metadata.ts | 77 ++ .../src/features/nodes/types/v2/openapi.ts | 86 ++ .../web/src/features/nodes/types/v2/semver.ts | 21 + .../src/features/nodes/types/v2/workflow.ts | 89 ++ .../web/src/features/nodes/types/workflow.ts | 10 +- .../nodes/util/node/buildInvocationNode.ts | 22 +- .../features/nodes/util/node/nodeUpdate.ts | 1 - .../util/schema/buildFieldInputInstance.ts | 3 - .../nodes/util/workflow/buildWorkflow.ts | 20 +- .../nodes/util/workflow/migrations.ts | 32 +- .../nodes/util/workflow/validateWorkflow.ts | 4 +- .../workflowLibrary/hooks/useSaveWorkflow.ts | 4 +- 80 files changed, 1940 insertions(+), 616 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/util.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/selectors.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/common.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/constants.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/error.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/field.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/semver.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 2e2d2014b2..ed8c82d91c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { cloneDeep } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index b2d3615909..88518e2c0b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 752c3b09df..ac1298da5b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => { actionCreator: updateAllNodesRequested, effect: (action, { dispatch, getState }) => { const log = logger('nodes'); - const nodes = getState().nodes.nodes; - const templates = getState().nodeTemplates.templates; + const { nodes, templates } = getState().nodes; let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 9307031e6d..ad41dc2654 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodeTemplates.templates; + const nodeTemplates = getState().nodes.templates; try { const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index e25e1351eb..270662c3d2 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice'; import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; -import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice'; @@ -46,7 +45,6 @@ const allReducers = { [gallerySlice.name]: gallerySlice.reducer, [generationSlice.name]: generationSlice.reducer, [nodesSlice.name]: nodesSlice.reducer, - [nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer, [postprocessingSlice.name]: postprocessingSlice.reducer, [systemSlice.name]: systemSlice.reducer, [configSlice.name]: configSlice.reducer, diff --git a/invokeai/frontend/web/src/app/store/storeHooks.ts b/invokeai/frontend/web/src/app/store/storeHooks.ts index f1a9aa979c..6bc904acb3 100644 --- a/invokeai/frontend/web/src/app/store/storeHooks.ts +++ b/invokeai/frontend/web/src/app/store/storeHooks.ts @@ -1,7 +1,8 @@ import type { AppThunkDispatch, RootState } from 'app/store/store'; import type { TypedUseSelectorHook } from 'react-redux'; -import { useDispatch, useSelector } from 'react-redux'; +import { useDispatch, useSelector, useStore } from 'react-redux'; // Use throughout your app instead of plain `useDispatch` and `useSelector` export const useAppDispatch = () => useDispatch(); export const useAppSelector: TypedUseSelectorHook = useSelector; +export const useAppStore = () => useStore(); diff --git a/invokeai/frontend/web/src/app/store/util.ts b/invokeai/frontend/web/src/app/store/util.ts new file mode 100644 index 0000000000..381f7f85d2 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/util.ts @@ -0,0 +1,2 @@ +export const EMPTY_ARRAY = []; +export const EMPTY_OBJECT = {}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 4952fa1c47..baa704e75c 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice'; @@ -23,11 +22,10 @@ const selector = createMemoizedSelector( selectGenerationSlice, selectSystemSlice, selectNodesSlice, - selectNodeTemplatesSlice, selectDynamicPromptsSlice, activeTabNameSelector, ], - (controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => { + (controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => { const { initialImage, model, positivePrompt } = generation; const { isConnected } = system; @@ -54,7 +52,7 @@ const selector = createMemoizedSelector( return; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; + const nodeTemplate = nodes.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b24b52c6ab..061209cafc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; @@ -54,10 +58,10 @@ const AddNodePopover = () => { const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => { + const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodeTemplates.templates, (template) => { + ? filter(nodes.templates, (template) => { const handles = handleFilter === 'source' ? template.inputs : template.outputs; return some(handles, (handle) => { @@ -67,7 +71,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodeTemplates.templates); + : map(nodes.templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 4bfc588e67..ba40b4984c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,10 +1,17 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; +const defaultReturnValue = { + isSelected: false, + shouldAnimate: false, + stroke: colorTokenToCssVar('base.500'), +}; + export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, @@ -12,14 +19,19 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index c287842f6e..b888e8a516 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,6 +1,5 @@ import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/invocation'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { map } from 'lodash-es'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; @@ -13,7 +12,7 @@ interface Props { const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' }; const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { - const data = useNodeData(nodeId); + const template = useNodeTemplate(nodeId); const { base600 } = useChakraThemeTokens(); const dummyHandleStyles: CSSProperties = useMemo( @@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { [dummyHandleStyles] ); - if (!isInvocationNodeData(data)) { + if (!template) { return null; } @@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { <> - {map(data.inputs, (input) => ( + {map(template.inputs, (input) => ( { ))} - {map(data.outputs, (output) => ( + {map(template.outputs, (output) => ( ) => { const { id: nodeId, type, isOpen, label } = data; const hasTemplateSelector = useMemo( - () => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])), + () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx index c2231f703a..e02b1a1474 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx @@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent'; interface Props { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; isMissingInput?: boolean; withTooltip?: boolean; } @@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => { return ( : undefined} + label={withTooltip ? : undefined} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} > { - const field = useFieldInstance(nodeId, fieldName); + const field = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); const isInputTemplate = isFieldInputTemplate(fieldTemplate); const fieldTypeName = useFieldTypeName(fieldTemplate?.type); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 2b9f7960e4..66b0d3f755 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { const [isHovered, setIsHovered] = useState(false); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'input' }); + useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { if (!fieldTemplate) { @@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { @@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c1d52c1d4f..b6e331c114 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,6 +1,5 @@ -import { Box, Text } from '@invoke-ai/ui-library'; -import { useFieldInstance } from 'features/nodes/hooks/useFieldData'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate, @@ -38,7 +37,6 @@ import { isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; @@ -63,17 +61,8 @@ type InputFieldProps = { }; const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { - const { t } = useTranslation(); - const fieldInstance = useFieldInstance(nodeId, fieldName); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - - if (fieldTemplate?.fieldKind === 'output') { - return ( - - {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} - - ); - } + const fieldInstance = useFieldInputInstance(nodeId, fieldName); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } - if (fieldInstance && fieldTemplate) { + if (fieldTemplate) { // Fallback for when there is no component for the type return null; } - - return ( - - - {t('nodes.unknownFieldType', { type: fieldInstance?.type.name })} - - - ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index d0a30ecc3c..0cd199f7a4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { /> - + {isValueChanged && ( { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 48c4c0d740..f2d776a2da 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,6 +1,5 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import type { PropsWithChildren } from 'react'; @@ -18,18 +17,17 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'output' }); + useConnectionState({ nodeId, fieldName, kind: 'outputs' }); - if (!fieldTemplate || !fieldInstance) { + if (!fieldTemplate) { return ( {t('nodes.unknownOutput', { - name: fieldTemplate?.title ?? fieldName, + name: fieldName, })} @@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" shouldWrapChildren diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index b7c9033d6b..d72d2f5aa8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index ee7dfaa693..978eeddd24 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 28f0e82d68..ea6e8ed704 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; return { template: lastSelectedNodeTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index d0263a8bda..c882924e24 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,26 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useAnyOrDirectInputFieldNames = (nodeId: string) => { +export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; - } - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index aecc931893..b19edf3c85 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial = { }; export const useBuildNode = () => { - const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates); + const nodeTemplates = useAppSelector((s) => s.nodes.templates); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 23f318517b..dc8a05b88c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,28 +1,24 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useConnectionInputFieldNames = (nodeId: string) => { +export const useConnectionInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } // get the visible fields - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (field.input === 'connection' && !field.type.isCollectionOrScalar) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index a6f8b663f6..97b96f323a 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector( export type UseConnectionStateProps = { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; }; export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { @@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.edges.filter((edge) => { return ( - (kind === 'input' ? edge.target : edge.source) === nodeId && - (kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName + (kind === 'inputs' ? edge.target : edge.source) === nodeId && + (kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName ); }).length ) @@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType), + () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), [nodeId, fieldName, kind, fieldType] ); @@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.handleId === fieldName && - nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind] + nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind] ) ), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index bfbf0a3b2d..91994cf752 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { compareVersions } from 'compare-versions'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoNodeVersionsMatch = (nodeId: string) => { +export const useDoNodeVersionsMatch = (nodeId: string): boolean => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { + createSelector(selectNodesSlice, (nodes) => { + const data = selectNodeData(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!template?.version || !data?.version) { return false; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - if (!nodeTemplate?.version || !node.data?.version) { - return false; - } - return compareVersions(nodeTemplate.version, node.data.version) === 0; + return compareVersions(template.version, data.version) === 0; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index cfe5c90d9c..5051eaa55b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -1,18 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { +export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + const data = selectNodeData(nodes, nodeId); + if (!data) { + return false; } - return node?.data.inputs[fieldName]?.value !== undefined; + return data.inputs[fieldName]?.value !== undefined; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts deleted file mode 100644 index 8b35a2d44b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldData = useAppSelector(selector); - - return fieldData; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts index 0793f1f952..25065e7aba 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts @@ -1,23 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputInstance = (nodeId: string, fieldName: string) => { +export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.inputs[fieldName]; + return selectFieldInputInstance(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); - const fieldTemplate = useAppSelector(selector); + const fieldData = useAppSelector(selector); - return fieldTemplate; + return fieldData; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 11d44dbde2..08de3d9b20 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -1,21 +1,16 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInput } from 'features/nodes/types/field'; import { useMemo } from 'react'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - const fieldTemplate = nodeTemplate?.inputs[fieldName]; - return fieldTemplate?.input; + createSelector(selectNodesSlice, (nodes): FieldInput | null => { + const template = selectFieldInputTemplate(nodes, nodeId, fieldName); + return template?.input ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 8533d2be8d..e8289d7e07 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.inputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldInputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index ef57956047..92eab8d1b1 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldLabel = (nodeId: string, fieldName: string) => { +export const useFieldLabel = (nodeId: string, fieldName: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]?.label; + return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts deleted file mode 100644 index 8b71f1ea01..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldOutputInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.outputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldTemplate = useAppSelector(selector); - - return fieldTemplate; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts index 11f592b399..cb154071e9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.outputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 663821da81..7be4ecfd4d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -1,21 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplate = ( + nodeId: string, + fieldName: string, + kind: 'inputs' | 'outputs' +): FieldInputTemplate | FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createMemoizedSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName); } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]; + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index cfdcda6efa..e41e019572 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -1,21 +1,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index a834726a13..a71a4d044e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -1,20 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldType } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null; } - const field = node.data[KIND_MAP[kind]][fieldName]; - return field?.type; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a8019c92d6..71344197d5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -1,13 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; -const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => +const selector = createSelector(selectNodesSlice, (nodes) => nodes.nodes.filter(isInvocationNode).some((node) => { - const template = nodeTemplates.templates[node.data.type]; + const template = nodes.templates[node.data.type]; if (!template) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 617e713c7c..3ac3cabb22 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -1,24 +1,21 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -export const useHasImageOutput = (nodeId: string) => { +export const useHasImageOutput = (nodeId: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } + const template = selectNodeTemplate(nodes, nodeId); return some( - node.data.outputs, + template?.outputs, (output) => output.type.name === 'ImageField' && // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes - node.data.type !== 'image' + template?.type !== 'image' ); }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 729bfa0cea..3fad0a2a86 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useIsIntermediate = (nodeId: string) => { +export const useIsIntermediate = (nodeId: string): boolean => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.isIntermediate; + return selectNodeData(nodes, nodeId)?.isIntermediate ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 39a8abbe7a..ded05c7b9b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,11 +1,10 @@ // TODO: enable this at some point -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; -import { useReactFlow } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -13,39 +12,34 @@ import { useReactFlow } from 'reactflow'; */ export const useIsValidConnection = () => { - const flow = useReactFlow(); + const store = useAppStore(); const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { - const edges = flow.getEdges(); - const nodes = flow.getNodes(); // Connection must have valid targets if (!(source && sourceHandle && target && targetHandle)) { return false; } - // Find the source and target nodes - const sourceNode = flow.getNode(source) as Node; - const targetNode = flow.getNode(target) as Node; - - // Conditional guards against undefined nodes/handles - if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) { - return false; - } - - const sourceField = sourceNode.data.outputs[sourceHandle]; - const targetField = targetNode.data.inputs[targetHandle]; - - if (!sourceField || !targetField) { - // something has gone terribly awry - return false; - } - if (source === target) { // Don't allow nodes to connect to themselves, even if validation is disabled return false; } + const state = store.getState(); + const { nodes, edges, templates } = state.nodes; + + // Find the source and target nodes + const sourceNode = nodes.find((node) => node.id === source) as Node; + const targetNode = nodes.find((node) => node.id === target) as Node; + const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; + const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; + + // Conditional guards against undefined nodes/handles + if (!(sourceFieldTemplate && targetFieldTemplate)) { + return false; + } + if (!shouldValidateGraph) { // manual override! return true; @@ -69,20 +63,20 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetField.type.name !== 'CollectionItemField' + targetFieldTemplate.type.name !== 'CollectionItemField' ) { return false; } // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) { + if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } // Graphs much be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, - [flow, shouldValidateGraph] + [shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts index c61721030e..bab8ff3f19 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts @@ -1,20 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { Classification } from 'features/nodes/types/common'; import { useMemo } from 'react'; -export const useNodeClassification = (nodeId: string) => { +export const useNodeClassification = (nodeId: string): Classification | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.classification; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.classification ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts index c507def5ee..fa21008ff8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts @@ -1,14 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectNodeData } from 'features/nodes/store/selectors'; +import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeData = (nodeId: string) => { +export const useNodeData = (nodeId: string): InvocationNodeData | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - return node?.data; + return selectNodeData(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index c5fc43742a..31dcb9c466 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,19 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - - return node.data.label; + return selectNodeData(nodes, nodeId)?.label ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index e6efa667f1..aa0294f70f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -1,21 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { useMemo } from 'react'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const template = nodeTemplates.templates[node?.data.type ?? '']; - if (isInvocationNode(node) && template) { - return getNeedsUpdate(node, template); + createMemoizedSelector(selectNodesSlice, (nodes) => { + const node = selectInvocationNode(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!node || !template) { + return false; } - return false; + return getNeedsUpdate(node, template); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts index ca3dd5cfdf..5c920866e9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodePack = (nodeId: string) => { +export const useNodePack = (nodeId: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.nodePack; + return selectNodeData(nodes, nodeId)?.nodePack ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts index 7544cbff46..866c9275fb 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts @@ -1,16 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplate = (nodeId: string) => { +export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 8fd1345f6f..a0c870f694 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -1,14 +1,14 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplateByType = (type: string) => { +export const useNodeTemplateByType = (type: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => { - return nodeTemplates.templates[type]; + createSelector(selectNodesSlice, (nodes) => { + return nodes.templates[type] ?? null; }), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 15d2ec38c3..120b8c758b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,21 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodeTemplateTitle = (nodeId: string) => { +export const useNodeTemplateTitle = (nodeId: string): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined; - - return nodeTemplate?.title; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.title ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e352bd8b90..24863080a7 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,8 +1,8 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { map } from 'lodash-es'; import { useMemo } from 'react'; @@ -10,17 +10,13 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - return getSortedFilteredFieldNames(map(nodeTemplate.outputs)); + return getSortedFilteredFieldNames(map(template.outputs)); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index edfc990882..aaca80039b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useUseCache = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.useCache; + return selectNodeData(nodes, nodeId)?.useCache ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts index 0e4806d81b..5d79c15442 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts @@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow'; import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import { useEffect } from 'react'; -export const $builtWorkflow = atom(null); +export const $builtWorkflow = atom(null); const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => { $builtWorkflow.set(buildWorkflowFast(arg)); diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 00457494bf..b32a3ba997 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,5 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { Graph } from 'services/api/types'; export const textToImageGraphBuilt = createAction('nodes/textToImageGraphBuilt'); @@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{ export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested'); -export const workflowLoaded = createAction('workflow/workflowLoaded'); +export const workflowLoaded = createAction('workflow/workflowLoaded'); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts deleted file mode 100644 index c211131aab..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import type { RootState } from 'app/store/store'; -import type { InvocationTemplate } from 'features/nodes/types/invocation'; - -import type { NodeTemplatesState } from './types'; - -export const initialNodeTemplatesState: NodeTemplatesState = { - templates: {}, -}; - -export const nodesTemplatesSlice = createSlice({ - name: 'nodeTemplates', - initialState: initialNodeTemplatesState, - reducers: { - nodeTemplatesBuilt: (state, action: PayloadAction>) => { - state.templates = action.payload; - }, - }, -}); - -export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions; - -export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index aee01b381b..6b596da063 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -42,7 +42,7 @@ import { zT2IAdapterModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; -import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation'; +import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { cloneDeep, forEach } from 'lodash-es'; import type { @@ -92,6 +92,7 @@ export const initialNodesState: NodesState = { _version: 1, nodes: [], edges: [], + templates: {}, connectionStartParams: null, connectionStartFieldType: null, connectionMade: false, @@ -190,6 +191,7 @@ export const nodesSlice = createSlice({ node, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -224,12 +226,12 @@ export const nodesSlice = createSlice({ if (!nodeId || !handleId) { return; } - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - const node = state.nodes?.[nodeIndex]; + const node = state.nodes.find((n) => n.id === nodeId); if (!isInvocationNode(node)) { return; } - const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; + const template = state.templates[node.data.type]; + const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId]; state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { @@ -260,6 +262,7 @@ export const nodesSlice = createSlice({ mouseOverNode, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -677,6 +680,9 @@ export const nodesSlice = createSlice({ selectionModeChanged: (state, action: PayloadAction) => { state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial; }, + nodeTemplatesBuilt: (state, action: PayloadAction>) => { + state.templates = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(workflowLoaded, (state, action) => { @@ -808,6 +814,7 @@ export const { shouldValidateGraphChanged, viewportChanged, edgeAdded, + nodeTemplatesBuilt, } = nodesSlice.actions; // This is used for tracking `state.workflow.isTouched` diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts new file mode 100644 index 0000000000..90675d6270 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -0,0 +1,51 @@ +import type { NodesState } from 'features/nodes/store/types'; +import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + return node; +}; + +export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => { + return selectInvocationNode(nodesSlice, nodeId)?.data ?? null; +}; + +export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => { + const node = selectInvocationNode(nodesSlice, nodeId); + if (!node) { + return null; + } + return nodesSlice.templates[node.data.type] ?? null; +}; + +export const selectFieldInputInstance = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputInstance | null => { + const data = selectNodeData(nodesSlice, nodeId); + return data?.inputs[fieldName] ?? null; +}; + +export const selectFieldInputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.inputs[fieldName] ?? null; +}; + +export const selectFieldOutputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldOutputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.outputs[fieldName] ?? null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 8b0de447e4..1a040d2c70 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -5,13 +5,14 @@ import type { InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow'; export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; + templates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; @@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & { value: StatefulFieldValue; }; -export type WorkflowsState = Omit & { +export type WorkflowsState = Omit & { _version: 1; isTouched: boolean; mode: WorkflowMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 9f2c37a2ad..ef899c5f41 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,4 +1,6 @@ -import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field'; +import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import type { Connection, Edge, HandleType, Node } from 'reactflow'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; @@ -9,7 +11,7 @@ const isValidConnection = ( handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: FieldInputInstance | FieldOutputInstance + handle: FieldInputTemplate | FieldOutputTemplate ) => { let isValidConnection = true; if (handleCurrentType === 'source') { @@ -38,24 +40,31 @@ const isValidConnection = ( }; export const findConnectionToValidHandle = ( - node: Node, - nodes: Node[], - edges: Edge[], + node: AnyNode, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + templates: Record, handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, handleCurrentFieldType: FieldType ): Connection | null => { - if (node.id === handleCurrentNodeId) { + if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { return null; } - const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs; + const template = templates[node.data.type]; + + if (!template) { + return null; + } + + const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; //Prioritize handles whos name matches the node we're coming from - if (handles[handleCurrentName]) { - const handle = handles[handleCurrentName]; + const handle = handles[handleCurrentName]; + if (handle) { const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; @@ -77,6 +86,9 @@ export const findConnectionToValidHandle = ( for (const handleName in handles) { const handle = handles[handleName]; + if (!handle) { + continue; + } const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 8575932cbd..d6ea0d9c86 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | null ) => { return createSelector(selectNodesSlice, (nodesSlice) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 2978f25138..4f40a68e1f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -10,10 +10,10 @@ import type { } from 'features/nodes/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow'; import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es'; -export const blankWorkflow: Omit = { +export const blankWorkflow: Omit = { name: '', author: '', description: '', @@ -22,7 +22,7 @@ export const blankWorkflow: Omit = { tags: '', notes: '', exposedFields: [], - meta: { version: '2.0.0', category: 'user' }, + meta: { version: '3.0.0', category: 'user' }, id: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 38f1af55dd..aa6164d6e5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -46,20 +46,11 @@ export type FieldInput = z.infer; export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); export type FieldUIComponent = z.infer; -export const zFieldInstanceBase = z.object({ - id: z.string().trim().min(1), +export const zFieldInputInstanceBase = z.object({ name: z.string().trim().min(1), -}); -export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('input'), label: z.string().nullish(), }); -export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('output'), -}); -export type FieldInstanceBase = z.infer; export type FieldInputInstanceBase = z.infer; -export type FieldOutputInstanceBase = z.infer; export const zFieldTemplateBase = z.object({ name: z.string().min(1), @@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({ }); export const zIntegerFieldValue = z.number().int(); export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIntegerFieldType, value: zIntegerFieldValue, }); -export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIntegerFieldType, -}); export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, default: zIntegerFieldValue, @@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({ }); export const zFloatFieldValue = z.number(); export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zFloatFieldType, value: zFloatFieldValue, }); -export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zFloatFieldType, -}); export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, default: zFloatFieldValue, @@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type FloatFieldType = z.infer; export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; -export type FloatFieldOutputInstance = z.infer; export type FloatFieldInputTemplate = z.infer; export type FloatFieldOutputTemplate = z.infer; export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => @@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({ }); export const zStringFieldValue = z.string(); export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStringFieldType, value: zStringFieldValue, }); -export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStringFieldType, -}); export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, default: zStringFieldValue, @@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StringFieldType = z.infer; export type StringFieldValue = z.infer; export type StringFieldInputInstance = z.infer; -export type StringFieldOutputInstance = z.infer; export type StringFieldInputTemplate = z.infer; export type StringFieldOutputTemplate = z.infer; export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => @@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({ }); export const zBooleanFieldValue = z.boolean(); export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBooleanFieldType, value: zBooleanFieldValue, }); -export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBooleanFieldType, -}); export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, default: zBooleanFieldValue, @@ -222,7 +195,6 @@ export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BooleanFieldType = z.infer; export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; -export type BooleanFieldOutputInstance = z.infer; export type BooleanFieldInputTemplate = z.infer; export type BooleanFieldOutputTemplate = z.infer; export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => @@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({ }); export const zEnumFieldValue = z.string(); export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zEnumFieldType, value: zEnumFieldValue, }); -export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zEnumFieldType, -}); export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, default: zEnumFieldValue, @@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type EnumFieldType = z.infer; export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; -export type EnumFieldOutputInstance = z.infer; export type EnumFieldInputTemplate = z.infer; export type EnumFieldOutputTemplate = z.infer; export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => @@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({ }); export const zImageFieldValue = zImageField.optional(); export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zImageFieldType, value: zImageFieldValue, }); -export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zImageFieldType, -}); export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, default: zImageFieldValue, @@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ImageFieldType = z.infer; export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; -export type ImageFieldOutputInstance = z.infer; export type ImageFieldInputTemplate = z.infer; export type ImageFieldOutputTemplate = z.infer; export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => @@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({ }); export const zBoardFieldValue = zBoardField.optional(); export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBoardFieldType, value: zBoardFieldValue, }); -export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBoardFieldType, -}); export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, default: zBoardFieldValue, @@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BoardFieldType = z.infer; export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; -export type BoardFieldOutputInstance = z.infer; export type BoardFieldInputTemplate = z.infer; export type BoardFieldOutputTemplate = z.infer; export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => @@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({ }); export const zColorFieldValue = zColorField.optional(); export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zColorFieldType, value: zColorFieldValue, }); -export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zColorFieldType, -}); export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, default: zColorFieldValue, @@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ColorFieldType = z.infer; export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; -export type ColorFieldOutputInstance = z.infer; export type ColorFieldInputTemplate = z.infer; export type ColorFieldOutputTemplate = z.infer; export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => @@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({ }); export const zMainModelFieldValue = zMainModelField.optional(); export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zMainModelFieldType, value: zMainModelFieldValue, }); -export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zMainModelFieldType, -}); export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, default: zMainModelFieldValue, @@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type MainModelFieldType = z.infer; export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; -export type MainModelFieldOutputInstance = z.infer; export type MainModelFieldInputTemplate = z.infer; export type MainModelFieldOutputTemplate = z.infer; export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => @@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLMainModelFieldType, value: zSDXLMainModelFieldValue, }); -export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLMainModelFieldType, -}); export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, default: zSDXLMainModelFieldValue, @@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend export type SDXLMainModelFieldType = z.infer; export type SDXLMainModelFieldValue = z.infer; export type SDXLMainModelFieldInputInstance = z.infer; -export type SDXLMainModelFieldOutputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; export type SDXLMainModelFieldOutputTemplate = z.infer; export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => @@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, value: zSDXLRefinerModelFieldValue, }); -export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, -}); export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, default: zSDXLRefinerModelFieldValue, @@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext export type SDXLRefinerModelFieldType = z.infer; export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; -export type SDXLRefinerModelFieldOutputInstance = z.infer; export type SDXLRefinerModelFieldInputTemplate = z.infer; export type SDXLRefinerModelFieldOutputTemplate = z.infer; export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => @@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({ }); export const zVAEModelFieldValue = zVAEModelField.optional(); export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zVAEModelFieldType, value: zVAEModelFieldValue, }); -export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zVAEModelFieldType, -}); export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, default: zVAEModelFieldValue, @@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type VAEModelFieldType = z.infer; export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; -export type VAEModelFieldOutputInstance = z.infer; export type VAEModelFieldInputTemplate = z.infer; export type VAEModelFieldOutputTemplate = z.infer; export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => @@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({ }); export const zLoRAModelFieldValue = zLoRAModelField.optional(); export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zLoRAModelFieldType, value: zLoRAModelFieldValue, }); -export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zLoRAModelFieldType, -}); export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, default: zLoRAModelFieldValue, @@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type LoRAModelFieldType = z.infer; export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; -export type LoRAModelFieldOutputInstance = z.infer; export type LoRAModelFieldInputTemplate = z.infer; export type LoRAModelFieldOutputTemplate = z.infer; export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => @@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({ }); export const zControlNetModelFieldValue = zControlNetModelField.optional(); export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zControlNetModelFieldType, value: zControlNetModelFieldValue, }); -export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zControlNetModelFieldType, -}); export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, default: zControlNetModelFieldValue, @@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type ControlNetModelFieldType = z.infer; export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; -export type ControlNetModelFieldOutputInstance = z.infer; export type ControlNetModelFieldInputTemplate = z.infer; export type ControlNetModelFieldOutputTemplate = z.infer; export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => @@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIPAdapterModelFieldType, value: zIPAdapterModelFieldValue, }); -export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIPAdapterModelFieldType, -}); export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue, @@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten export type IPAdapterModelFieldType = z.infer; export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; -export type IPAdapterModelFieldOutputInstance = z.infer; export type IPAdapterModelFieldInputTemplate = z.infer; export type IPAdapterModelFieldOutputTemplate = z.infer; export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => @@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, value: zT2IAdapterModelFieldValue, }); -export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, -}); export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, default: zT2IAdapterModelFieldValue, @@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type T2IAdapterModelFieldType = z.infer; export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; -export type T2IAdapterModelFieldOutputInstance = z.infer; export type T2IAdapterModelFieldInputTemplate = z.infer; export type T2IAdapterModelFieldOutputTemplate = z.infer; export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => @@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({ }); export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSchedulerFieldType, value: zSchedulerFieldValue, }); -export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSchedulerFieldType, -}); export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, default: zSchedulerFieldValue, @@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type SchedulerFieldType = z.infer; export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; -export type SchedulerFieldOutputInstance = z.infer; export type SchedulerFieldInputTemplate = z.infer; export type SchedulerFieldOutputTemplate = z.infer; export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => @@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({ }); export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStatelessFieldType, value: zStatelessFieldValue, }); -export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStatelessFieldType, -}); export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, default: zStatelessFieldValue, @@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StatelessFieldType = z.infer; export type StatelessFieldValue = z.infer; export type StatelessFieldInputInstance = z.infer; -export type StatelessFieldOutputInstance = z.infer; export type StatelessFieldInputTemplate = z.infer; export type StatelessFieldOutputTemplate = z.infer; // #endregion @@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => zFieldInputInstance.safeParse(val).success; // #endregion -// #region StatefulFieldOutputInstance & FieldOutputInstance -export const zStatefulFieldOutputInstance = z.union([ - zIntegerFieldOutputInstance, - zFloatFieldOutputInstance, - zStringFieldOutputInstance, - zBooleanFieldOutputInstance, - zEnumFieldOutputInstance, - zImageFieldOutputInstance, - zBoardFieldOutputInstance, - zMainModelFieldOutputInstance, - zSDXLMainModelFieldOutputInstance, - zSDXLRefinerModelFieldOutputInstance, - zVAEModelFieldOutputInstance, - zLoRAModelFieldOutputInstance, - zControlNetModelFieldOutputInstance, - zIPAdapterModelFieldOutputInstance, - zT2IAdapterModelFieldOutputInstance, - zColorFieldOutputInstance, - zSchedulerFieldOutputInstance, -]); -export type StatefulFieldOutputInstance = z.infer; -export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => - zStatefulFieldOutputInstance.safeParse(val).success; - -export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); -export type FieldOutputInstance = z.infer; -export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => - zFieldOutputInstance.safeParse(val).success; -// #endregion - // #region StatefulFieldInputTemplate & FieldInputTemplate export const zStatefulFieldInputTemplate = z.union([ zIntegerFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 86ec70fd9b..5ccb19430d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zClassification, zProgressImage } from './common'; -import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field'; import { zSemVer } from './semver'; // #region InvocationTemplate @@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer; // #region NodeData export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), - type: z.string().trim().min(1), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), - isIntermediate: z.boolean(), - useCache: z.boolean(), version: zSemVer, nodePack: z.string().min(1).nullish(), + label: z.string(), + notes: z.string(), + type: z.string().trim().min(1), inputs: z.record(zFieldInputInstance), - outputs: z.record(zFieldOutputInstance), + isOpen: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean(), }); export const zNotesNodeData = z.object({ @@ -62,11 +61,12 @@ export type NotesNode = Node; export type CurrentImageNode = Node; export type AnyNode = Node; -export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => +export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); -export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => +export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData => Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts new file mode 100644 index 0000000000..b524474379 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zClassification = z.enum(['stable', 'beta', 'prototype']); +export type Classification = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelName = z.string().min(3); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, +}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + type: z.literal('image_output'), +}); +export type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts new file mode 100644 index 0000000000..35ef9e9fd2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts @@ -0,0 +1,80 @@ +import type { Node } from 'reactflow'; + +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ +export const HANDLE_TOOLTIP_OPEN_DELAY = 500; + +/** + * The width of a node in the UI in pixels. + */ +export const NODE_WIDTH = 320; + +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; + +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + +/** + * Helper for getting the kind of a field. + */ +export const KIND_MAP = { + input: 'inputs' as const, + output: 'outputs' as const, +}; + +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ + 'IPAdapterModelField', + 'ControlNetModelField', + 'LoRAModelField', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'UNetField', + 'VaeField', + 'ClipField', + 'T2IAdapterModelField', + 'IPAdapterModelField', +]; + +/** + * Colors for each field type - applies to their handles and edges. + */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/error.ts b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts new file mode 100644 index 0000000000..905b487fb0 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts @@ -0,0 +1,58 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} +/** + * Workflow Migration Error + * Raised when a workflow migration fails. + */ +export class WorkflowMigrationError extends Error { + /** + * Create WorkflowMigrationError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. + */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts new file mode 100644 index 0000000000..38f1af55dd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -0,0 +1,875 @@ +import { z } from 'zod'; + +import { + zBoardField, + zColorField, + zControlNetModelField, + zImageField, + zIPAdapterModelField, + zLoRAModelField, + zMainModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for fields. + * + * These schemas and types are only required for stateful field - fields that have UI components + * and allow the user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. + * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * - inferred types for each schema + * - type guards for InputInstance and InputTemplate + * + * These then must be added to the unions at the bottom of this file. + */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isCollectionOrScalar: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer; +export type IntegerFieldInputTemplate = z.infer; +export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer; +export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer; +export type StringFieldOutputInstance = z.infer; +export type StringFieldInputTemplate = z.infer; +export type StringFieldOutputTemplate = z.infer; +export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer; +export type BooleanFieldOutputInstance = z.infer; +export type BooleanFieldInputTemplate = z.infer; +export type BooleanFieldOutputTemplate = z.infer; +export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => + zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer; +export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer; +export type BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer; +export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer; +export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer; +export type MainModelFieldOutputInstance = z.infer; +export type MainModelFieldInputTemplate = z.infer; +export type MainModelFieldOutputTemplate = z.infer; +export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, +}); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, +}); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer; +export type SDXLMainModelFieldOutputInstance = z.infer; +export type SDXLMainModelFieldInputTemplate = z.infer; +export type SDXLMainModelFieldOutputTemplate = z.infer; +export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export type SDXLRefinerModelFieldType = z.infer; +export type SDXLRefinerModelFieldValue = z.infer; +export type SDXLRefinerModelFieldInputInstance = z.infer; +export type SDXLRefinerModelFieldOutputInstance = z.infer; +export type SDXLRefinerModelFieldInputTemplate = z.infer; +export type SDXLRefinerModelFieldOutputTemplate = z.infer; +export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = z.infer; +export type VAEModelFieldInputInstance = z.infer; +export type VAEModelFieldOutputInstance = z.infer; +export type VAEModelFieldInputTemplate = z.infer; +export type VAEModelFieldOutputTemplate = z.infer; +export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer; +export type LoRAModelFieldOutputInstance = z.infer; +export type LoRAModelFieldInputTemplate = z.infer; +export type LoRAModelFieldOutputTemplate = z.infer; +export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, +}); +export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, +}); +export type ControlNetModelFieldType = z.infer; +export type ControlNetModelFieldValue = z.infer; +export type ControlNetModelFieldInputInstance = z.infer; +export type ControlNetModelFieldOutputInstance = z.infer; +export type ControlNetModelFieldInputTemplate = z.infer; +export type ControlNetModelFieldOutputTemplate = z.infer; +export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue => + zControlNetModelFieldValue.safeParse(v).success; +// #endregion + +// #region IPAdapterModelField +export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, +}); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + default: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, +}); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer; +export type IPAdapterModelFieldInputInstance = z.infer; +export type IPAdapterModelFieldOutputInstance = z.infer; +export type IPAdapterModelFieldInputTemplate = z.infer; +export type IPAdapterModelFieldOutputTemplate = z.infer; +export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue => + zIPAdapterModelFieldValue.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export type T2IAdapterModelFieldType = z.infer; +export type T2IAdapterModelFieldValue = z.infer; +export type T2IAdapterModelFieldInputInstance = z.infer; +export type T2IAdapterModelFieldOutputInstance = z.infer; +export type T2IAdapterModelFieldInputTemplate = z.infer; +export type T2IAdapterModelFieldOutputTemplate = z.infer; +export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSchedulerFieldType, + default: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSchedulerFieldType, +}); +export type SchedulerFieldType = z.infer; +export type SchedulerFieldValue = z.infer; +export type SchedulerFieldInputInstance = z.infer; +export type SchedulerFieldOutputInstance = z.infer; +export type SchedulerFieldInputTemplate = z.infer; +export type SchedulerFieldOutputTemplate = z.infer; +export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => + zSchedulerFieldInputInstance.safeParse(val).success; +export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate => + zSchedulerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatelessField +/** + * StatelessField is a catchall for stateless fields with no UI input components. They do not + * do not support "direct" input, instead only accepting connections from other fields. + * + * This field type serves as a "generic" field type. + * + * Examples include: + * - Fields like UNetField or LatentsField where we do not allow direct UI input + * - Reserved fields like IsIntermediate + * - Any other field we don't have full-on schemas for + */ +export const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); +export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling +export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStatelessFieldType, + value: zStatelessFieldValue, +}); +export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStatelessFieldType, +}); +export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStatelessFieldType, + default: zStatelessFieldValue, + input: z.literal('connection'), // stateless --> only accepts connection inputs +}); +export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStatelessFieldType, +}); + +export type StatelessFieldType = z.infer; +export type StatelessFieldValue = z.infer; +export type StatelessFieldInputInstance = z.infer; +export type StatelessFieldOutputInstance = z.infer; +export type StatelessFieldInputTemplate = z.infer; +export type StatelessFieldOutputTemplate = z.infer; +// #endregion + +/** + * Here we define the main field unions: + * - FieldType + * - FieldValue + * - FieldInputInstance + * - FieldOutputInstance + * - FieldInputTemplate + * - FieldOutputTemplate + * + * All stateful fields are unioned together, and then that union is unioned with StatelessField. + * + * This allows us to interact with stateful fields without needing to worry about "generic" handling + * for all other StatelessFields. + */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer; +export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer; +export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => + zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer; +export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate => + zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts new file mode 100644 index 0000000000..86ec70fd9b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts @@ -0,0 +1,93 @@ +import type { Edge, Node } from 'reactflow'; +import { z } from 'zod'; + +import { zClassification, zProgressImage } from './common'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + version: zSemVer, + useCache: z.boolean(), + nodePack: z.string().min(1).nullish(), + classification: zClassification, +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + nodePack: z.string().min(1).nullish(), + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationNodeEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts new file mode 100644 index 0000000000..0cc30499e3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts @@ -0,0 +1,77 @@ +import { z } from 'zod'; + +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataItem = zMainModelField.deepPartial(); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer; +export type ModelMetadataItem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + cfg_rescale_multiplier: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataItem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts new file mode 100644 index 0000000000..83d774439a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts @@ -0,0 +1,86 @@ +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { + InputFieldJSONSchemaExtra, + InvocationJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = InvocationJSONSchemaExtra & { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject +): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts new file mode 100644 index 0000000000..3ba330eac4 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts @@ -0,0 +1,21 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts new file mode 100644 index 0000000000..723a354013 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts @@ -0,0 +1,89 @@ +import { z } from 'zod'; + +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; + +export const zWorkflowCategory = z.enum(['user', 'default', 'project']); +export type WorkflowCategory = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + id: z.string().min(1).optional(), + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + category: zWorkflowCategory.default('user'), + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 723a354013..adad7c0f21 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({ id: z.string().trim().min(1), type: z.literal('invocation'), data: zInvocationNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNotesNode = z.object({ id: z.string().trim().min(1), type: z.literal('notes'), data: zNotesNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); @@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer; // #endregion // #region Workflow -export const zWorkflowV2 = z.object({ +export const zWorkflowV3 = z.object({ id: z.string().min(1).optional(), name: z.string(), author: z.string(), @@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('2.0.0'), + version: z.literal('3.0.0'), }), }); -export type WorkflowV2 = z.infer; +export type WorkflowV3 = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index ea40bd4660..af19aa86ea 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -1,5 +1,5 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; -import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { reduce } from 'lodash-es'; @@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe {} as Record ); - const outputs = reduce( - template.outputs, - (outputsAccumulator, outputTemplate, outputName) => { - const fieldId = uuidv4(); - - const outputFieldValue: FieldOutputInstance = { - id: fieldId, - name: outputName, - type: outputTemplate.type, - fieldKind: 'output', - }; - - outputsAccumulator[outputName] = outputFieldValue; - - return outputsAccumulator; - }, - {} as Record - ); - const node: InvocationNode = { ...SHARED_NODE_PROPERTIES, id: nodeId, @@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe isIntermediate: type === 'save_image' ? false : true, useCache: template.useCache, inputs, - outputs, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts index f195c49d30..5ece51d0f3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts @@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate): // Remove any fields that are not in the template clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs)); - clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs)); return clone; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index dd3cf0ad7b..f8097566c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record = export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { const fieldInstance: FieldInputInstance = { - id, name: template.name, - type: template.type, label: '', - fieldKind: 'input' as const, value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name), }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 70775a9882..720da16464 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import type { NodesState, WorkflowsState } from 'features/nodes/store/types'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; import { cloneDeep, pick } from 'lodash-es'; import { fromZodError } from 'zod-validation-error'; @@ -25,14 +25,14 @@ const workflowKeys = [ 'exposedFields', 'meta', 'id', -] satisfies (keyof WorkflowV2)[]; +] satisfies (keyof WorkflowV3)[]; -export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2; +export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3; -export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => { +export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => { const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys); - const newWorkflow: WorkflowV2 = { + const newWorkflow: WorkflowV3 = { ...clonedWorkflow, nodes: [], edges: [], @@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } else if (isNotesNode(node) && node.type) { newWorkflow.nodes.push({ @@ -54,8 +52,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } }); @@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo return newWorkflow; }; -export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => { +export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => { // builds what really, really should be a valid workflow const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow }); // but bc we are storing this in the DB, let's be extra sure - const result = zWorkflowV2.safeParse(workflowToValidate); + const result = zWorkflowV3.safeParse(workflowToValidate); if (!result.success) { const { message } = fromZodError(result.error, { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index a2677f3d17..a023c96ba9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; +import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { z } from 'zod'; @@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodeTemplates.templates; + const invocationTemplates = $store.get()?.getState().nodes.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized')); @@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { return zWorkflowV2.parse(workflowToMigrate); }; +const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { + // Bump version + (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; + // Parsing strips out any extra properties not in the latest version + return zWorkflowV3.parse(workflowToMigrate); +}; + /** * Parses a workflow and migrates it to the latest version if necessary. */ -export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); if (!workflowVersionResult.success) { throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); } - const { version } = workflowVersionResult.data.meta; + let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3; - if (version === '1.0.0') { - const v1 = zWorkflowV1.parse(data); - return migrateV1toV2(v1); + if (workflow.meta.version === '1.0.0') { + const v1 = zWorkflowV1.parse(workflow); + workflow = migrateV1toV2(v1); } - if (version === '2.0.0') { - return zWorkflowV2.parse(data); + if (workflow.meta.version === '2.0.0') { + const v2 = zWorkflowV2.parse(workflow); + workflow = migrateV2toV3(v2); } - throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version })); + return workflow as WorkflowV3; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 848d2aee77..5096e588b0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,6 @@ import { parseify } from 'common/util/serialize'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { t } from 'i18next'; @@ -16,7 +16,7 @@ type WorkflowWarning = { }; type ValidateWorkflowResult = { - workflow: WorkflowV2; + workflow: WorkflowV3; warnings: WorkflowWarning[]; }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts index 5d484b6897..7b49d70213 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts @@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher'; import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { workflowUpdated } from 'features/workflowLibrary/store/actions'; import { useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; @@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = { type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn; -export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required => +export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required => Boolean(workflow.id); export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => {