diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 2e2d2014b2..ed8c82d91c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { cloneDeep } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index b2d3615909..88518e2c0b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 752c3b09df..ac1298da5b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -15,8 +15,7 @@ export const addUpdateAllNodesRequestedListener = () => { actionCreator: updateAllNodesRequested, effect: (action, { dispatch, getState }) => { const log = logger('nodes'); - const nodes = getState().nodes.nodes; - const templates = getState().nodeTemplates.templates; + const { nodes, templates } = getState().nodes; let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 9307031e6d..ad41dc2654 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -18,7 +18,7 @@ export const addWorkflowLoadRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodeTemplates.templates; + const nodeTemplates = getState().nodes.templates; try { const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index e25e1351eb..270662c3d2 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -16,7 +16,6 @@ import { hrfPersistConfig, hrfSlice } from 'features/hrf/store/hrfSlice'; import { loraPersistConfig, loraSlice } from 'features/lora/store/loraSlice'; import { modelManagerPersistConfig, modelManagerSlice } from 'features/modelManager/store/modelManagerSlice'; import { nodesPersistConfig, nodesSlice } from 'features/nodes/store/nodesSlice'; -import { nodesTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { workflowPersistConfig, workflowSlice } from 'features/nodes/store/workflowSlice'; import { generationPersistConfig, generationSlice } from 'features/parameters/store/generationSlice'; import { postprocessingPersistConfig, postprocessingSlice } from 'features/parameters/store/postprocessingSlice'; @@ -46,7 +45,6 @@ const allReducers = { [gallerySlice.name]: gallerySlice.reducer, [generationSlice.name]: generationSlice.reducer, [nodesSlice.name]: nodesSlice.reducer, - [nodesTemplatesSlice.name]: nodesTemplatesSlice.reducer, [postprocessingSlice.name]: postprocessingSlice.reducer, [systemSlice.name]: systemSlice.reducer, [configSlice.name]: configSlice.reducer, diff --git a/invokeai/frontend/web/src/app/store/storeHooks.ts b/invokeai/frontend/web/src/app/store/storeHooks.ts index f1a9aa979c..6bc904acb3 100644 --- a/invokeai/frontend/web/src/app/store/storeHooks.ts +++ b/invokeai/frontend/web/src/app/store/storeHooks.ts @@ -1,7 +1,8 @@ import type { AppThunkDispatch, RootState } from 'app/store/store'; import type { TypedUseSelectorHook } from 'react-redux'; -import { useDispatch, useSelector } from 'react-redux'; +import { useDispatch, useSelector, useStore } from 'react-redux'; // Use throughout your app instead of plain `useDispatch` and `useSelector` export const useAppDispatch = () => useDispatch(); export const useAppSelector: TypedUseSelectorHook = useSelector; +export const useAppStore = () => useStore(); diff --git a/invokeai/frontend/web/src/app/store/util.ts b/invokeai/frontend/web/src/app/store/util.ts new file mode 100644 index 0000000000..381f7f85d2 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/util.ts @@ -0,0 +1,2 @@ +export const EMPTY_ARRAY = []; +export const EMPTY_OBJECT = {}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 4952fa1c47..baa704e75c 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -8,7 +8,6 @@ import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice'; @@ -23,11 +22,10 @@ const selector = createMemoizedSelector( selectGenerationSlice, selectSystemSlice, selectNodesSlice, - selectNodeTemplatesSlice, selectDynamicPromptsSlice, activeTabNameSelector, ], - (controlAdapters, generation, system, nodes, nodeTemplates, dynamicPrompts, activeTabName) => { + (controlAdapters, generation, system, nodes, dynamicPrompts, activeTabName) => { const { initialImage, model, positivePrompt } = generation; const { isConnected } = system; @@ -54,7 +52,7 @@ const selector = createMemoizedSelector( return; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; + const nodeTemplate = nodes.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index b24b52c6ab..061209cafc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -7,8 +7,12 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { + addNodePopoverClosed, + addNodePopoverOpened, + nodeAdded, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; @@ -54,10 +58,10 @@ const AddNodePopover = () => { const fieldFilter = useAppSelector((s) => s.nodes.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates) => { + const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodeTemplates.templates, (template) => { + ? filter(nodes.templates, (template) => { const handles = handleFilter === 'source' ? template.inputs : template.outputs; return some(handles, (handle) => { @@ -67,7 +71,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodeTemplates.templates); + : map(nodes.templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 4bfc588e67..ba40b4984c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,10 +1,17 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; +const defaultReturnValue = { + isSelected: false, + shouldAnimate: false, + stroke: colorTokenToCssVar('base.500'), +}; + export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, @@ -12,14 +19,19 @@ export const makeEdgeSelector = ( targetHandleId: string | null | undefined, selected?: boolean ) => - createMemoizedSelector(selectNodesSlice, (nodes) => { + createMemoizedSelector(selectNodesSlice, (nodes): { isSelected: boolean; shouldAnimate: boolean; stroke: string } => { const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = sourceNode?.selected || targetNode?.selected || selected; - const sourceType = isInvocationToInvocationEdge ? sourceNode?.data?.outputs[sourceHandleId || '']?.type : undefined; + const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + if (!sourceNode || !sourceHandleId) { + return defaultReturnValue; + } + + const outputFieldTemplate = selectFieldOutputTemplate(nodes, sourceNode.id, sourceHandleId); + const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; const stroke = sourceType && nodes.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index c287842f6e..b888e8a516 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,6 +1,5 @@ import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; -import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/invocation'; +import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; import { map } from 'lodash-es'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; @@ -13,7 +12,7 @@ interface Props { const hiddenHandleStyles: CSSProperties = { visibility: 'hidden' }; const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { - const data = useNodeData(nodeId); + const template = useNodeTemplate(nodeId); const { base600 } = useChakraThemeTokens(); const dummyHandleStyles: CSSProperties = useMemo( @@ -37,7 +36,7 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { [dummyHandleStyles] ); - if (!isInvocationNodeData(data)) { + if (!template) { return null; } @@ -45,14 +44,14 @@ const InvocationNodeCollapsedHandles = ({ nodeId }: Props) => { <> - {map(data.inputs, (input) => ( + {map(template.inputs, (input) => ( { ))} - {map(data.outputs, (output) => ( + {map(template.outputs, (output) => ( ) => { const { id: nodeId, type, isOpen, label } = data; const hasTemplateSelector = useMemo( - () => createSelector(selectNodeTemplatesSlice, (nodeTemplates) => Boolean(nodeTemplates.templates[type])), + () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx index c2231f703a..e02b1a1474 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/EditableFieldTitle.tsx @@ -22,7 +22,7 @@ import FieldTooltipContent from './FieldTooltipContent'; interface Props { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; isMissingInput?: boolean; withTooltip?: boolean; } @@ -58,7 +58,7 @@ const EditableFieldTitle = forwardRef((props: Props, ref) => { return ( : undefined} + label={withTooltip ? : undefined} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} > { - const field = useFieldInstance(nodeId, fieldName); + const field = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); const isInputTemplate = isFieldInputTemplate(fieldTemplate); const fieldTypeName = useFieldTypeName(fieldTemplate?.type); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 2b9f7960e4..66b0d3f755 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -25,7 +25,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { const [isHovered, setIsHovered] = useState(false); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'input' }); + useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { if (!fieldTemplate) { @@ -76,7 +76,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { @@ -101,7 +101,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index c1d52c1d4f..b6e331c114 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,6 +1,5 @@ -import { Box, Text } from '@invoke-ai/ui-library'; -import { useFieldInstance } from 'features/nodes/hooks/useFieldData'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate, @@ -38,7 +37,6 @@ import { isVAEModelFieldInputTemplate, } from 'features/nodes/types/field'; import { memo } from 'react'; -import { useTranslation } from 'react-i18next'; import BoardFieldInputComponent from './inputs/BoardFieldInputComponent'; import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent'; @@ -63,17 +61,8 @@ type InputFieldProps = { }; const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { - const { t } = useTranslation(); - const fieldInstance = useFieldInstance(nodeId, fieldName); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); - - if (fieldTemplate?.fieldKind === 'output') { - return ( - - {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} - - ); - } + const fieldInstance = useFieldInputInstance(nodeId, fieldName); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; @@ -141,18 +130,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } - if (fieldInstance && fieldTemplate) { + if (fieldTemplate) { // Fallback for when there is no component for the type return null; } - - return ( - - - {t('nodes.unknownFieldType', { type: fieldInstance?.type.name })} - - - ); }; export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index d0a30ecc3c..0cd199f7a4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -62,7 +62,7 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { /> - + {isValueChanged && ( { /> )} } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 48c4c0d740..f2d776a2da 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,6 +1,5 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@invoke-ai/ui-library'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import type { PropsWithChildren } from 'react'; @@ -18,18 +17,17 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = - useConnectionState({ nodeId, fieldName, kind: 'output' }); + useConnectionState({ nodeId, fieldName, kind: 'outputs' }); - if (!fieldTemplate || !fieldInstance) { + if (!fieldTemplate) { return ( {t('nodes.unknownOutput', { - name: fieldTemplate?.title ?? fieldName, + name: fieldName, })} @@ -40,7 +38,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { return ( } + label={} openDelay={HANDLE_TOOLTIP_OPEN_DELAY} placement="top" shouldWrapChildren diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index b7c9033d6b..d72d2f5aa8 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -6,19 +6,18 @@ import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableCon import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { return; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index ee7dfaa693..978eeddd24 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,6 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -14,12 +13,12 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 28f0e82d68..ea6e8ed704 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -3,16 +3,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { +const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - const lastSelectedNodeTemplate = lastSelectedNode ? nodeTemplates.templates[lastSelectedNode.data.type] : undefined; + const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; return { template: lastSelectedNodeTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index d0263a8bda..c882924e24 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,26 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useAnyOrDirectInputFieldNames = (nodeId: string) => { +export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; - } - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index aecc931893..b19edf3c85 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -13,7 +13,7 @@ export const SHARED_NODE_PROPERTIES: Partial = { }; export const useBuildNode = () => { - const nodeTemplates = useAppSelector((s) => s.nodeTemplates.templates); + const nodeTemplates = useAppSelector((s) => s.nodes.templates); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 23f318517b..dc8a05b88c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,28 +1,24 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -export const useConnectionInputFieldNames = (nodeId: string) => { +export const useConnectionInputFieldNames = (nodeId: string): string[] => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createMemoizedSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } // get the visible fields - const fields = map(nodeTemplate.inputs).filter( + const fields = map(template.inputs).filter( (field) => (field.input === 'connection' && !field.type.isCollectionOrScalar) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index a6f8b663f6..97b96f323a 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -14,7 +14,7 @@ const selectIsConnectionInProgress = createSelector( export type UseConnectionStateProps = { nodeId: string; fieldName: string; - kind: 'input' | 'output'; + kind: 'inputs' | 'outputs'; }; export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { @@ -26,8 +26,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.edges.filter((edge) => { return ( - (kind === 'input' ? edge.target : edge.source) === nodeId && - (kind === 'input' ? edge.targetHandle : edge.sourceHandle) === fieldName + (kind === 'inputs' ? edge.target : edge.source) === nodeId && + (kind === 'inputs' ? edge.targetHandle : edge.sourceHandle) === fieldName ); }).length ) @@ -36,7 +36,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'input' ? 'target' : 'source', fieldType), + () => makeConnectionErrorSelector(nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), [nodeId, fieldName, kind, fieldType] ); @@ -46,7 +46,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta Boolean( nodes.connectionStartParams?.nodeId === nodeId && nodes.connectionStartParams?.handleId === fieldName && - nodes.connectionStartParams?.handleType === { input: 'target', output: 'source' }[kind] + nodes.connectionStartParams?.handleType === { inputs: 'target', outputs: 'source' }[kind] ) ), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index bfbf0a3b2d..91994cf752 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -2,23 +2,19 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { compareVersions } from 'compare-versions'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData, selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoNodeVersionsMatch = (nodeId: string) => { +export const useDoNodeVersionsMatch = (nodeId: string): boolean => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { + createSelector(selectNodesSlice, (nodes) => { + const data = selectNodeData(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!template?.version || !data?.version) { return false; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - if (!nodeTemplate?.version || !node.data?.version) { - return false; - } - return compareVersions(nodeTemplate.version, node.data.version) === 0; + return compareVersions(template.version, data.version) === 0; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index cfe5c90d9c..5051eaa55b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -1,18 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { +export const useDoesInputHaveValue = (nodeId: string, fieldName: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + const data = selectNodeData(nodes, nodeId); + if (!data) { + return false; } - return node?.data.inputs[fieldName]?.value !== undefined; + return data.inputs[fieldName]?.value !== undefined; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts deleted file mode 100644 index 8b35a2d44b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldData = useAppSelector(selector); - - return fieldData; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts index 0793f1f952..25065e7aba 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts @@ -1,23 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputInstance = (nodeId: string, fieldName: string) => { +export const useFieldInputInstance = (nodeId: string, fieldName: string): FieldInputInstance | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.inputs[fieldName]; + return selectFieldInputInstance(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); - const fieldTemplate = useAppSelector(selector); + const fieldData = useAppSelector(selector); - return fieldTemplate; + return fieldData; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 11d44dbde2..08de3d9b20 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -1,21 +1,16 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInput } from 'features/nodes/types/field'; import { useMemo } from 'react'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - const fieldTemplate = nodeTemplate?.inputs[fieldName]; - return fieldTemplate?.input; + createSelector(selectNodesSlice, (nodes): FieldInput | null => { + const template = selectFieldInputTemplate(nodes, nodeId, fieldName); + return template?.input ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 8533d2be8d..e8289d7e07 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldInputTemplate = (nodeId: string, fieldName: string): FieldInputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.inputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldInputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index ef57956047..92eab8d1b1 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputInstance } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldLabel = (nodeId: string, fieldName: string) => { +export const useFieldLabel = (nodeId: string, fieldName: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node?.data.inputs[fieldName]?.label; + return selectFieldInputInstance(nodes, nodeId, fieldName)?.label ?? null; }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts deleted file mode 100644 index 8b71f1ea01..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts +++ /dev/null @@ -1,23 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; -import { useMemo } from 'react'; - -export const useFieldOutputInstance = (nodeId: string, fieldName: string) => { - const selector = useMemo( - () => - createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - return node.data.outputs[fieldName]; - }), - [fieldName, nodeId] - ); - - const fieldTemplate = useAppSelector(selector); - - return fieldTemplate; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts index 11f592b399..cb154071e9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -1,20 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { +export const useFieldOutputTemplate = (nodeId: string, fieldName: string): FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.outputs[fieldName]; + createMemoizedSelector(selectNodesSlice, (nodes) => { + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 663821da81..7be4ecfd4d 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -1,21 +1,22 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldTemplate = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplate = ( + nodeId: string, + fieldName: string, + kind: 'inputs' | 'outputs' +): FieldInputTemplate | FieldOutputTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createMemoizedSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName); } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]; + return selectFieldOutputTemplate(nodes, nodeId, fieldName); }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index cfdcda6efa..e41e019572 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -1,21 +1,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldTemplateTitle = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + createSelector(selectNodesSlice, (nodes) => { + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.title ?? null; } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.title ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index a834726a13..a71a4d044e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -1,20 +1,18 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { KIND_MAP } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectFieldInputTemplate, selectFieldOutputTemplate } from 'features/nodes/store/selectors'; +import type { FieldType } from 'features/nodes/types/field'; import { useMemo } from 'react'; -export const useFieldType = (nodeId: string, fieldName: string, kind: 'input' | 'output') => { +export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return; + if (kind === 'inputs') { + return selectFieldInputTemplate(nodes, nodeId, fieldName)?.type ?? null; } - const field = node.data[KIND_MAP[kind]][fieldName]; - return field?.type; + return selectFieldOutputTemplate(nodes, nodeId, fieldName)?.type ?? null; }), [fieldName, kind, nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a8019c92d6..71344197d5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -1,13 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; -const selector = createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => +const selector = createSelector(selectNodesSlice, (nodes) => nodes.nodes.filter(isInvocationNode).some((node) => { - const template = nodeTemplates.templates[node.data.type]; + const template = nodes.templates[node.data.type]; if (!template) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 617e713c7c..3ac3cabb22 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -1,24 +1,21 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -export const useHasImageOutput = (nodeId: string) => { +export const useHasImageOutput = (nodeId: string): boolean => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } + const template = selectNodeTemplate(nodes, nodeId); return some( - node.data.outputs, + template?.outputs, (output) => output.type.name === 'ImageField' && // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes - node.data.type !== 'image' + template?.type !== 'image' ); }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 729bfa0cea..3fad0a2a86 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useIsIntermediate = (nodeId: string) => { +export const useIsIntermediate = (nodeId: string): boolean => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.isIntermediate; + return selectNodeData(nodes, nodeId)?.isIntermediate ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 39a8abbe7a..ded05c7b9b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,11 +1,10 @@ // TODO: enable this at some point -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; -import { useReactFlow } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -13,39 +12,34 @@ import { useReactFlow } from 'reactflow'; */ export const useIsValidConnection = () => { - const flow = useReactFlow(); + const store = useAppStore(); const shouldValidateGraph = useAppSelector((s) => s.nodes.shouldValidateGraph); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { - const edges = flow.getEdges(); - const nodes = flow.getNodes(); // Connection must have valid targets if (!(source && sourceHandle && target && targetHandle)) { return false; } - // Find the source and target nodes - const sourceNode = flow.getNode(source) as Node; - const targetNode = flow.getNode(target) as Node; - - // Conditional guards against undefined nodes/handles - if (!(sourceNode && targetNode && sourceNode.data && targetNode.data)) { - return false; - } - - const sourceField = sourceNode.data.outputs[sourceHandle]; - const targetField = targetNode.data.inputs[targetHandle]; - - if (!sourceField || !targetField) { - // something has gone terribly awry - return false; - } - if (source === target) { // Don't allow nodes to connect to themselves, even if validation is disabled return false; } + const state = store.getState(); + const { nodes, edges, templates } = state.nodes; + + // Find the source and target nodes + const sourceNode = nodes.find((node) => node.id === source) as Node; + const targetNode = nodes.find((node) => node.id === target) as Node; + const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; + const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; + + // Conditional guards against undefined nodes/handles + if (!(sourceFieldTemplate && targetFieldTemplate)) { + return false; + } + if (!shouldValidateGraph) { // manual override! return true; @@ -69,20 +63,20 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetField.type.name !== 'CollectionItemField' + targetFieldTemplate.type.name !== 'CollectionItemField' ) { return false; } // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) { + if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } // Graphs much be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, - [flow, shouldValidateGraph] + [shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts index c61721030e..bab8ff3f19 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts @@ -1,20 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { Classification } from 'features/nodes/types/common'; import { useMemo } from 'react'; -export const useNodeClassification = (nodeId: string) => { +export const useNodeClassification = (nodeId: string): Classification | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate?.classification; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.classification ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts index c507def5ee..fa21008ff8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeData.ts @@ -1,14 +1,15 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { selectNodeData } from 'features/nodes/store/selectors'; +import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeData = (nodeId: string) => { +export const useNodeData = (nodeId: string): InvocationNodeData | null => { const selector = useMemo( () => createMemoizedSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - return node?.data; + return selectNodeData(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index c5fc43742a..31dcb9c466 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,19 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - - return node.data.label; + return selectNodeData(nodes, nodeId)?.label ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index e6efa667f1..aa0294f70f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -1,21 +1,20 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectInvocationNode, selectNodeTemplate } from 'features/nodes/store/selectors'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { useMemo } from 'react'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const template = nodeTemplates.templates[node?.data.type ?? '']; - if (isInvocationNode(node) && template) { - return getNeedsUpdate(node, template); + createMemoizedSelector(selectNodesSlice, (nodes) => { + const node = selectInvocationNode(nodes, nodeId); + const template = selectNodeTemplate(nodes, nodeId); + if (!node || !template) { + return false; } - return false; + return getNeedsUpdate(node, template); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts index ca3dd5cfdf..5c920866e9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodePack.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodePack = (nodeId: string) => { +export const useNodePack = (nodeId: string): string | null => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.nodePack; + return selectNodeData(nodes, nodeId)?.nodePack ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts index 7544cbff46..866c9275fb 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts @@ -1,16 +1,15 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplate = (nodeId: string) => { +export const useNodeTemplate = (nodeId: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; - return nodeTemplate; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 8fd1345f6f..a0c870f694 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -1,14 +1,14 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; -export const useNodeTemplateByType = (type: string) => { +export const useNodeTemplateByType = (type: string): InvocationTemplate | null => { const selector = useMemo( () => - createMemoizedSelector(selectNodeTemplatesSlice, (nodeTemplates): InvocationTemplate | undefined => { - return nodeTemplates.templates[type]; + createSelector(selectNodesSlice, (nodes) => { + return nodes.templates[type] ?? null; }), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 15d2ec38c3..120b8c758b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,21 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; -export const useNodeTemplateTitle = (nodeId: string) => { +export const useNodeTemplateTitle = (nodeId: string): string | null => { const selector = useMemo( () => - createSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - const nodeTemplate = node ? nodeTemplates.templates[node.data.type] : undefined; - - return nodeTemplate?.title; + createSelector(selectNodesSlice, (nodes) => { + return selectNodeTemplate(nodes, nodeId)?.title ?? null; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e352bd8b90..24863080a7 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -1,8 +1,8 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; +import { EMPTY_ARRAY } from 'app/store/util'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeTemplatesSlice } from 'features/nodes/store/nodeTemplatesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeTemplate } from 'features/nodes/store/selectors'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { map } from 'lodash-es'; import { useMemo } from 'react'; @@ -10,17 +10,13 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(selectNodesSlice, selectNodeTemplatesSlice, (nodes, nodeTemplates) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return []; - } - const nodeTemplate = nodeTemplates.templates[node.data.type]; - if (!nodeTemplate) { - return []; + createSelector(selectNodesSlice, (nodes) => { + const template = selectNodeTemplate(nodes, nodeId); + if (!template) { + return EMPTY_ARRAY; } - return getSortedFilteredFieldNames(map(nodeTemplate.outputs)); + return getSortedFilteredFieldNames(map(template.outputs)); }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index edfc990882..aaca80039b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -1,18 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useUseCache = (nodeId: string) => { const selector = useMemo( () => createSelector(selectNodesSlice, (nodes) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - if (!isInvocationNode(node)) { - return false; - } - return node.data.useCache; + return selectNodeData(nodes, nodeId)?.useCache ?? false; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts index 0e4806d81b..5d79c15442 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflowWatcher.ts @@ -2,14 +2,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { BuildWorkflowArg } from 'features/nodes/util/workflow/buildWorkflow'; import { buildWorkflowFast } from 'features/nodes/util/workflow/buildWorkflow'; import { debounce } from 'lodash-es'; import { atom } from 'nanostores'; import { useEffect } from 'react'; -export const $builtWorkflow = atom(null); +export const $builtWorkflow = atom(null); const debouncedBuildWorkflow = debounce((arg: BuildWorkflowArg) => { $builtWorkflow.set(buildWorkflowFast(arg)); diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 00457494bf..b32a3ba997 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,5 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { Graph } from 'services/api/types'; export const textToImageGraphBuilt = createAction('nodes/textToImageGraphBuilt'); @@ -21,4 +21,4 @@ export const workflowLoadRequested = createAction<{ export const updateAllNodesRequested = createAction('nodes/updateAllNodesRequested'); -export const workflowLoaded = createAction('workflow/workflowLoaded'); +export const workflowLoaded = createAction('workflow/workflowLoaded'); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts deleted file mode 100644 index c211131aab..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { PayloadAction } from '@reduxjs/toolkit'; -import { createSlice } from '@reduxjs/toolkit'; -import type { RootState } from 'app/store/store'; -import type { InvocationTemplate } from 'features/nodes/types/invocation'; - -import type { NodeTemplatesState } from './types'; - -export const initialNodeTemplatesState: NodeTemplatesState = { - templates: {}, -}; - -export const nodesTemplatesSlice = createSlice({ - name: 'nodeTemplates', - initialState: initialNodeTemplatesState, - reducers: { - nodeTemplatesBuilt: (state, action: PayloadAction>) => { - state.templates = action.payload; - }, - }, -}); - -export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions; - -export const selectNodeTemplatesSlice = (state: RootState) => state.nodeTemplates; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index aee01b381b..6b596da063 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -42,7 +42,7 @@ import { zT2IAdapterModelFieldValue, zVAEModelFieldValue, } from 'features/nodes/types/field'; -import type { AnyNode, NodeExecutionState } from 'features/nodes/types/invocation'; +import type { AnyNode, InvocationTemplate, NodeExecutionState } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode, zNodeStatus } from 'features/nodes/types/invocation'; import { cloneDeep, forEach } from 'lodash-es'; import type { @@ -92,6 +92,7 @@ export const initialNodesState: NodesState = { _version: 1, nodes: [], edges: [], + templates: {}, connectionStartParams: null, connectionStartFieldType: null, connectionMade: false, @@ -190,6 +191,7 @@ export const nodesSlice = createSlice({ node, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -224,12 +226,12 @@ export const nodesSlice = createSlice({ if (!nodeId || !handleId) { return; } - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - const node = state.nodes?.[nodeIndex]; + const node = state.nodes.find((n) => n.id === nodeId); if (!isInvocationNode(node)) { return; } - const field = handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; + const template = state.templates[node.data.type]; + const field = handleType === 'source' ? template?.outputs[handleId] : template?.inputs[handleId]; state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { @@ -260,6 +262,7 @@ export const nodesSlice = createSlice({ mouseOverNode, state.nodes, state.edges, + state.templates, nodeId, handleId, handleType, @@ -677,6 +680,9 @@ export const nodesSlice = createSlice({ selectionModeChanged: (state, action: PayloadAction) => { state.selectionMode = action.payload ? SelectionMode.Full : SelectionMode.Partial; }, + nodeTemplatesBuilt: (state, action: PayloadAction>) => { + state.templates = action.payload; + }, }, extraReducers: (builder) => { builder.addCase(workflowLoaded, (state, action) => { @@ -808,6 +814,7 @@ export const { shouldValidateGraphChanged, viewportChanged, edgeAdded, + nodeTemplatesBuilt, } = nodesSlice.actions; // This is used for tracking `state.workflow.isTouched` diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts new file mode 100644 index 0000000000..90675d6270 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -0,0 +1,51 @@ +import type { NodesState } from 'features/nodes/store/types'; +import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + return node; +}; + +export const selectNodeData = (nodesSlice: NodesState, nodeId: string): InvocationNodeData | null => { + return selectInvocationNode(nodesSlice, nodeId)?.data ?? null; +}; + +export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => { + const node = selectInvocationNode(nodesSlice, nodeId); + if (!node) { + return null; + } + return nodesSlice.templates[node.data.type] ?? null; +}; + +export const selectFieldInputInstance = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputInstance | null => { + const data = selectNodeData(nodesSlice, nodeId); + return data?.inputs[fieldName] ?? null; +}; + +export const selectFieldInputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldInputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.inputs[fieldName] ?? null; +}; + +export const selectFieldOutputTemplate = ( + nodesSlice: NodesState, + nodeId: string, + fieldName: string +): FieldOutputTemplate | null => { + const template = selectNodeTemplate(nodesSlice, nodeId); + return template?.outputs[fieldName] ?? null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 8b0de447e4..1a040d2c70 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -5,13 +5,14 @@ import type { InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import type { OnConnectStartParams, SelectionMode, Viewport, XYPosition } from 'reactflow'; export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; + templates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; @@ -38,7 +39,7 @@ export type FieldIdentifierWithValue = FieldIdentifier & { value: StatefulFieldValue; }; -export type WorkflowsState = Omit & { +export type WorkflowsState = Omit & { _version: 1; isTouched: boolean; mode: WorkflowMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 9f2c37a2ad..ef899c5f41 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,4 +1,6 @@ -import type { FieldInputInstance, FieldOutputInstance, FieldType } from 'features/nodes/types/field'; +import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import type { Connection, Edge, HandleType, Node } from 'reactflow'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; @@ -9,7 +11,7 @@ const isValidConnection = ( handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: FieldInputInstance | FieldOutputInstance + handle: FieldInputTemplate | FieldOutputTemplate ) => { let isValidConnection = true; if (handleCurrentType === 'source') { @@ -38,24 +40,31 @@ const isValidConnection = ( }; export const findConnectionToValidHandle = ( - node: Node, - nodes: Node[], - edges: Edge[], + node: AnyNode, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + templates: Record, handleCurrentNodeId: string, handleCurrentName: string, handleCurrentType: HandleType, handleCurrentFieldType: FieldType ): Connection | null => { - if (node.id === handleCurrentNodeId) { + if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { return null; } - const handles = handleCurrentType === 'source' ? node.data.inputs : node.data.outputs; + const template = templates[node.data.type]; + + if (!template) { + return null; + } + + const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; //Prioritize handles whos name matches the node we're coming from - if (handles[handleCurrentName]) { - const handle = handles[handleCurrentName]; + const handle = handles[handleCurrentName]; + if (handle) { const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; @@ -77,6 +86,9 @@ export const findConnectionToValidHandle = ( for (const handleName in handles) { const handle = handles[handleName]; + if (!handle) { + continue; + } const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 8575932cbd..d6ea0d9c86 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -16,7 +16,7 @@ export const makeConnectionErrorSelector = ( nodeId: string, fieldName: string, handleType: HandleType, - fieldType?: FieldType + fieldType?: FieldType | null ) => { return createSelector(selectNodesSlice, (nodesSlice) => { if (!fieldType) { diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 2978f25138..4f40a68e1f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -10,10 +10,10 @@ import type { } from 'features/nodes/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import type { WorkflowCategory, WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowCategory, WorkflowV3 } from 'features/nodes/types/workflow'; import { cloneDeep, isEqual, omit, uniqBy } from 'lodash-es'; -export const blankWorkflow: Omit = { +export const blankWorkflow: Omit = { name: '', author: '', description: '', @@ -22,7 +22,7 @@ export const blankWorkflow: Omit = { tags: '', notes: '', exposedFields: [], - meta: { version: '2.0.0', category: 'user' }, + meta: { version: '3.0.0', category: 'user' }, id: undefined, }; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 38f1af55dd..aa6164d6e5 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -46,20 +46,11 @@ export type FieldInput = z.infer; export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); export type FieldUIComponent = z.infer; -export const zFieldInstanceBase = z.object({ - id: z.string().trim().min(1), +export const zFieldInputInstanceBase = z.object({ name: z.string().trim().min(1), -}); -export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('input'), label: z.string().nullish(), }); -export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ - fieldKind: z.literal('output'), -}); -export type FieldInstanceBase = z.infer; export type FieldInputInstanceBase = z.infer; -export type FieldOutputInstanceBase = z.infer; export const zFieldTemplateBase = z.object({ name: z.string().min(1), @@ -102,12 +93,8 @@ export const zIntegerFieldType = zFieldTypeBase.extend({ }); export const zIntegerFieldValue = z.number().int(); export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIntegerFieldType, value: zIntegerFieldValue, }); -export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIntegerFieldType, -}); export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, default: zIntegerFieldValue, @@ -136,12 +123,8 @@ export const zFloatFieldType = zFieldTypeBase.extend({ }); export const zFloatFieldValue = z.number(); export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zFloatFieldType, value: zFloatFieldValue, }); -export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zFloatFieldType, -}); export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, default: zFloatFieldValue, @@ -157,7 +140,6 @@ export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type FloatFieldType = z.infer; export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; -export type FloatFieldOutputInstance = z.infer; export type FloatFieldInputTemplate = z.infer; export type FloatFieldOutputTemplate = z.infer; export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => @@ -172,12 +154,8 @@ export const zStringFieldType = zFieldTypeBase.extend({ }); export const zStringFieldValue = z.string(); export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStringFieldType, value: zStringFieldValue, }); -export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStringFieldType, -}); export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, default: zStringFieldValue, @@ -191,7 +169,6 @@ export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StringFieldType = z.infer; export type StringFieldValue = z.infer; export type StringFieldInputInstance = z.infer; -export type StringFieldOutputInstance = z.infer; export type StringFieldInputTemplate = z.infer; export type StringFieldOutputTemplate = z.infer; export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => @@ -206,12 +183,8 @@ export const zBooleanFieldType = zFieldTypeBase.extend({ }); export const zBooleanFieldValue = z.boolean(); export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBooleanFieldType, value: zBooleanFieldValue, }); -export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBooleanFieldType, -}); export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, default: zBooleanFieldValue, @@ -222,7 +195,6 @@ export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BooleanFieldType = z.infer; export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; -export type BooleanFieldOutputInstance = z.infer; export type BooleanFieldInputTemplate = z.infer; export type BooleanFieldOutputTemplate = z.infer; export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => @@ -237,12 +209,8 @@ export const zEnumFieldType = zFieldTypeBase.extend({ }); export const zEnumFieldValue = z.string(); export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zEnumFieldType, value: zEnumFieldValue, }); -export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zEnumFieldType, -}); export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, default: zEnumFieldValue, @@ -255,7 +223,6 @@ export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type EnumFieldType = z.infer; export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; -export type EnumFieldOutputInstance = z.infer; export type EnumFieldInputTemplate = z.infer; export type EnumFieldOutputTemplate = z.infer; export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => @@ -270,12 +237,8 @@ export const zImageFieldType = zFieldTypeBase.extend({ }); export const zImageFieldValue = zImageField.optional(); export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zImageFieldType, value: zImageFieldValue, }); -export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zImageFieldType, -}); export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, default: zImageFieldValue, @@ -286,7 +249,6 @@ export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ImageFieldType = z.infer; export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; -export type ImageFieldOutputInstance = z.infer; export type ImageFieldInputTemplate = z.infer; export type ImageFieldOutputTemplate = z.infer; export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => @@ -301,12 +263,8 @@ export const zBoardFieldType = zFieldTypeBase.extend({ }); export const zBoardFieldValue = zBoardField.optional(); export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zBoardFieldType, value: zBoardFieldValue, }); -export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zBoardFieldType, -}); export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, default: zBoardFieldValue, @@ -317,7 +275,6 @@ export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type BoardFieldType = z.infer; export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; -export type BoardFieldOutputInstance = z.infer; export type BoardFieldInputTemplate = z.infer; export type BoardFieldOutputTemplate = z.infer; export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => @@ -332,12 +289,8 @@ export const zColorFieldType = zFieldTypeBase.extend({ }); export const zColorFieldValue = zColorField.optional(); export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zColorFieldType, value: zColorFieldValue, }); -export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zColorFieldType, -}); export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, default: zColorFieldValue, @@ -348,7 +301,6 @@ export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type ColorFieldType = z.infer; export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; -export type ColorFieldOutputInstance = z.infer; export type ColorFieldInputTemplate = z.infer; export type ColorFieldOutputTemplate = z.infer; export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => @@ -363,12 +315,8 @@ export const zMainModelFieldType = zFieldTypeBase.extend({ }); export const zMainModelFieldValue = zMainModelField.optional(); export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zMainModelFieldType, value: zMainModelFieldValue, }); -export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zMainModelFieldType, -}); export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, default: zMainModelFieldValue, @@ -379,7 +327,6 @@ export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type MainModelFieldType = z.infer; export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; -export type MainModelFieldOutputInstance = z.infer; export type MainModelFieldInputTemplate = z.infer; export type MainModelFieldOutputTemplate = z.infer; export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => @@ -394,12 +341,8 @@ export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLMainModelFieldType, value: zSDXLMainModelFieldValue, }); -export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLMainModelFieldType, -}); export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, default: zSDXLMainModelFieldValue, @@ -410,7 +353,6 @@ export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend export type SDXLMainModelFieldType = z.infer; export type SDXLMainModelFieldValue = z.infer; export type SDXLMainModelFieldInputInstance = z.infer; -export type SDXLMainModelFieldOutputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; export type SDXLMainModelFieldOutputTemplate = z.infer; export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => @@ -425,12 +367,8 @@ export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ }); export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, value: zSDXLRefinerModelFieldValue, }); -export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSDXLRefinerModelFieldType, -}); export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, default: zSDXLRefinerModelFieldValue, @@ -441,7 +379,6 @@ export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.ext export type SDXLRefinerModelFieldType = z.infer; export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; -export type SDXLRefinerModelFieldOutputInstance = z.infer; export type SDXLRefinerModelFieldInputTemplate = z.infer; export type SDXLRefinerModelFieldOutputTemplate = z.infer; export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => @@ -456,12 +393,8 @@ export const zVAEModelFieldType = zFieldTypeBase.extend({ }); export const zVAEModelFieldValue = zVAEModelField.optional(); export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zVAEModelFieldType, value: zVAEModelFieldValue, }); -export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zVAEModelFieldType, -}); export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, default: zVAEModelFieldValue, @@ -472,7 +405,6 @@ export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type VAEModelFieldType = z.infer; export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; -export type VAEModelFieldOutputInstance = z.infer; export type VAEModelFieldInputTemplate = z.infer; export type VAEModelFieldOutputTemplate = z.infer; export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => @@ -487,12 +419,8 @@ export const zLoRAModelFieldType = zFieldTypeBase.extend({ }); export const zLoRAModelFieldValue = zLoRAModelField.optional(); export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zLoRAModelFieldType, value: zLoRAModelFieldValue, }); -export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zLoRAModelFieldType, -}); export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, default: zLoRAModelFieldValue, @@ -503,7 +431,6 @@ export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type LoRAModelFieldType = z.infer; export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; -export type LoRAModelFieldOutputInstance = z.infer; export type LoRAModelFieldInputTemplate = z.infer; export type LoRAModelFieldOutputTemplate = z.infer; export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => @@ -518,12 +445,8 @@ export const zControlNetModelFieldType = zFieldTypeBase.extend({ }); export const zControlNetModelFieldValue = zControlNetModelField.optional(); export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zControlNetModelFieldType, value: zControlNetModelFieldValue, }); -export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zControlNetModelFieldType, -}); export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, default: zControlNetModelFieldValue, @@ -534,7 +457,6 @@ export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type ControlNetModelFieldType = z.infer; export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; -export type ControlNetModelFieldOutputInstance = z.infer; export type ControlNetModelFieldInputTemplate = z.infer; export type ControlNetModelFieldOutputTemplate = z.infer; export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => @@ -551,12 +473,8 @@ export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zIPAdapterModelFieldType, value: zIPAdapterModelFieldValue, }); -export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zIPAdapterModelFieldType, -}); export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue, @@ -567,7 +485,6 @@ export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exten export type IPAdapterModelFieldType = z.infer; export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; -export type IPAdapterModelFieldOutputInstance = z.infer; export type IPAdapterModelFieldInputTemplate = z.infer; export type IPAdapterModelFieldOutputTemplate = z.infer; export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => @@ -584,12 +501,8 @@ export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ }); export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, value: zT2IAdapterModelFieldValue, }); -export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zT2IAdapterModelFieldType, -}); export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, default: zT2IAdapterModelFieldValue, @@ -600,7 +513,6 @@ export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.exte export type T2IAdapterModelFieldType = z.infer; export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; -export type T2IAdapterModelFieldOutputInstance = z.infer; export type T2IAdapterModelFieldInputTemplate = z.infer; export type T2IAdapterModelFieldOutputTemplate = z.infer; export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => @@ -615,12 +527,8 @@ export const zSchedulerFieldType = zFieldTypeBase.extend({ }); export const zSchedulerFieldValue = zSchedulerField.optional(); export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zSchedulerFieldType, value: zSchedulerFieldValue, }); -export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zSchedulerFieldType, -}); export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, default: zSchedulerFieldValue, @@ -631,7 +539,6 @@ export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type SchedulerFieldType = z.infer; export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; -export type SchedulerFieldOutputInstance = z.infer; export type SchedulerFieldInputTemplate = z.infer; export type SchedulerFieldOutputTemplate = z.infer; export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => @@ -657,12 +564,8 @@ export const zStatelessFieldType = zFieldTypeBase.extend({ }); export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ - type: zStatelessFieldType, value: zStatelessFieldValue, }); -export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ - type: zStatelessFieldType, -}); export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, default: zStatelessFieldValue, @@ -675,7 +578,6 @@ export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ export type StatelessFieldType = z.infer; export type StatelessFieldValue = z.infer; export type StatelessFieldInputInstance = z.infer; -export type StatelessFieldOutputInstance = z.infer; export type StatelessFieldInputTemplate = z.infer; export type StatelessFieldOutputTemplate = z.infer; // #endregion @@ -783,36 +685,6 @@ export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => zFieldInputInstance.safeParse(val).success; // #endregion -// #region StatefulFieldOutputInstance & FieldOutputInstance -export const zStatefulFieldOutputInstance = z.union([ - zIntegerFieldOutputInstance, - zFloatFieldOutputInstance, - zStringFieldOutputInstance, - zBooleanFieldOutputInstance, - zEnumFieldOutputInstance, - zImageFieldOutputInstance, - zBoardFieldOutputInstance, - zMainModelFieldOutputInstance, - zSDXLMainModelFieldOutputInstance, - zSDXLRefinerModelFieldOutputInstance, - zVAEModelFieldOutputInstance, - zLoRAModelFieldOutputInstance, - zControlNetModelFieldOutputInstance, - zIPAdapterModelFieldOutputInstance, - zT2IAdapterModelFieldOutputInstance, - zColorFieldOutputInstance, - zSchedulerFieldOutputInstance, -]); -export type StatefulFieldOutputInstance = z.infer; -export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => - zStatefulFieldOutputInstance.safeParse(val).success; - -export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); -export type FieldOutputInstance = z.infer; -export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => - zFieldOutputInstance.safeParse(val).success; -// #endregion - // #region StatefulFieldInputTemplate & FieldInputTemplate export const zStatefulFieldInputTemplate = z.union([ zIntegerFieldInputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 86ec70fd9b..5ccb19430d 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -2,7 +2,7 @@ import type { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zClassification, zProgressImage } from './common'; -import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field'; import { zSemVer } from './semver'; // #region InvocationTemplate @@ -25,16 +25,15 @@ export type InvocationTemplate = z.infer; // #region NodeData export const zInvocationNodeData = z.object({ id: z.string().trim().min(1), - type: z.string().trim().min(1), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), - isIntermediate: z.boolean(), - useCache: z.boolean(), version: zSemVer, nodePack: z.string().min(1).nullish(), + label: z.string(), + notes: z.string(), + type: z.string().trim().min(1), inputs: z.record(zFieldInputInstance), - outputs: z.record(zFieldOutputInstance), + isOpen: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean(), }); export const zNotesNodeData = z.object({ @@ -62,11 +61,12 @@ export type NotesNode = Node; export type CurrentImageNode = Node; export type AnyNode = Node; -export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => +export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode | null): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); -export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => +export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData => Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/common.ts b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts new file mode 100644 index 0000000000..b524474379 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/common.ts @@ -0,0 +1,188 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zClassification = z.enum(['stable', 'beta', 'prototype']); +export type Classification = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +export const zModelType = z.enum(['main', 'vae', 'lora', 'controlnet', 'embedding']); +export const zModelName = z.string().min(3); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, +}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z.enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']).optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']).optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int().gt(0), + height: z.number().int().gt(0), + type: z.literal('image_output'), +}); +export type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts new file mode 100644 index 0000000000..35ef9e9fd2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/constants.ts @@ -0,0 +1,80 @@ +import type { Node } from 'reactflow'; + +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ +export const HANDLE_TOOLTIP_OPEN_DELAY = 500; + +/** + * The width of a node in the UI in pixels. + */ +export const NODE_WIDTH = 320; + +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ +export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; + +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + +/** + * Helper for getting the kind of a field. + */ +export const KIND_MAP = { + input: 'inputs' as const, + output: 'outputs' as const, +}; + +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ + 'IPAdapterModelField', + 'ControlNetModelField', + 'LoRAModelField', + 'MainModelField', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'VaeModelField', + 'UNetField', + 'VaeField', + 'ClipField', + 'T2IAdapterModelField', + 'IPAdapterModelField', +]; + +/** + * Colors for each field type - applies to their handles and edges. + */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/error.ts b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts new file mode 100644 index 0000000000..905b487fb0 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/error.ts @@ -0,0 +1,58 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} +/** + * Workflow Migration Error + * Raised when a workflow migration fails. + */ +export class WorkflowMigrationError extends Error { + /** + * Create WorkflowMigrationError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. + */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts new file mode 100644 index 0000000000..38f1af55dd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -0,0 +1,875 @@ +import { z } from 'zod'; + +import { + zBoardField, + zColorField, + zControlNetModelField, + zImageField, + zIPAdapterModelField, + zLoRAModelField, + zMainModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for fields. + * + * These schemas and types are only required for stateful field - fields that have UI components + * and allow the user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. + * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * - inferred types for each schema + * - type guards for InputInstance and InputTemplate + * + * These then must be added to the unions at the bottom of this file. + */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isCollectionOrScalar: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer; +export type IntegerFieldInputTemplate = z.infer; +export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer; +export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer; +export type StringFieldOutputInstance = z.infer; +export type StringFieldInputTemplate = z.infer; +export type StringFieldOutputTemplate = z.infer; +export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer; +export type BooleanFieldOutputInstance = z.infer; +export type BooleanFieldInputTemplate = z.infer; +export type BooleanFieldOutputTemplate = z.infer; +export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance => + zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer; +export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer; +export type BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer; +export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer; +export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer; +export type MainModelFieldOutputInstance = z.infer; +export type MainModelFieldInputTemplate = z.infer; +export type MainModelFieldOutputTemplate = z.infer; +export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, +}); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, +}); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer; +export type SDXLMainModelFieldOutputInstance = z.infer; +export type SDXLMainModelFieldInputTemplate = z.infer; +export type SDXLMainModelFieldOutputTemplate = z.infer; +export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, +}); +export const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, +}); +export type SDXLRefinerModelFieldType = z.infer; +export type SDXLRefinerModelFieldValue = z.infer; +export type SDXLRefinerModelFieldInputInstance = z.infer; +export type SDXLRefinerModelFieldOutputInstance = z.infer; +export type SDXLRefinerModelFieldInputTemplate = z.infer; +export type SDXLRefinerModelFieldOutputTemplate = z.infer; +export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = z.infer; +export type VAEModelFieldInputInstance = z.infer; +export type VAEModelFieldOutputInstance = z.infer; +export type VAEModelFieldInputTemplate = z.infer; +export type VAEModelFieldOutputTemplate = z.infer; +export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer; +export type LoRAModelFieldOutputInstance = z.infer; +export type LoRAModelFieldInputTemplate = z.infer; +export type LoRAModelFieldOutputTemplate = z.infer; +export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, +}); +export const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, +}); +export const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, +}); +export type ControlNetModelFieldType = z.infer; +export type ControlNetModelFieldValue = z.infer; +export type ControlNetModelFieldInputInstance = z.infer; +export type ControlNetModelFieldOutputInstance = z.infer; +export type ControlNetModelFieldInputTemplate = z.infer; +export type ControlNetModelFieldOutputTemplate = z.infer; +export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +export const isControlNetModelFieldValue = (v: unknown): v is ControlNetModelFieldValue => + zControlNetModelFieldValue.safeParse(v).success; +// #endregion + +// #region IPAdapterModelField +export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, +}); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + default: zIPAdapterModelFieldValue, +}); +export const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, +}); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer; +export type IPAdapterModelFieldInputInstance = z.infer; +export type IPAdapterModelFieldOutputInstance = z.infer; +export type IPAdapterModelFieldInputTemplate = z.infer; +export type IPAdapterModelFieldOutputTemplate = z.infer; +export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +export const isIPAdapterModelFieldValue = (val: unknown): val is IPAdapterModelFieldValue => + zIPAdapterModelFieldValue.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, +}); +export const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, +}); +export type T2IAdapterModelFieldType = z.infer; +export type T2IAdapterModelFieldValue = z.infer; +export type T2IAdapterModelFieldInputInstance = z.infer; +export type T2IAdapterModelFieldOutputInstance = z.infer; +export type T2IAdapterModelFieldInputTemplate = z.infer; +export type T2IAdapterModelFieldOutputTemplate = z.infer; +export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSchedulerFieldType, + default: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSchedulerFieldType, +}); +export type SchedulerFieldType = z.infer; +export type SchedulerFieldValue = z.infer; +export type SchedulerFieldInputInstance = z.infer; +export type SchedulerFieldOutputInstance = z.infer; +export type SchedulerFieldInputTemplate = z.infer; +export type SchedulerFieldOutputTemplate = z.infer; +export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance => + zSchedulerFieldInputInstance.safeParse(val).success; +export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate => + zSchedulerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatelessField +/** + * StatelessField is a catchall for stateless fields with no UI input components. They do not + * do not support "direct" input, instead only accepting connections from other fields. + * + * This field type serves as a "generic" field type. + * + * Examples include: + * - Fields like UNetField or LatentsField where we do not allow direct UI input + * - Reserved fields like IsIntermediate + * - Any other field we don't have full-on schemas for + */ +export const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); +export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling +export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStatelessFieldType, + value: zStatelessFieldValue, +}); +export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStatelessFieldType, +}); +export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStatelessFieldType, + default: zStatelessFieldValue, + input: z.literal('connection'), // stateless --> only accepts connection inputs +}); +export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStatelessFieldType, +}); + +export type StatelessFieldType = z.infer; +export type StatelessFieldValue = z.infer; +export type StatelessFieldInputInstance = z.infer; +export type StatelessFieldOutputInstance = z.infer; +export type StatelessFieldInputTemplate = z.infer; +export type StatelessFieldOutputTemplate = z.infer; +// #endregion + +/** + * Here we define the main field unions: + * - FieldType + * - FieldValue + * - FieldInputInstance + * - FieldOutputInstance + * - FieldInputTemplate + * - FieldOutputTemplate + * + * All stateful fields are unioned together, and then that union is unioned with StatelessField. + * + * This allows us to interact with stateful fields without needing to worry about "generic" handling + * for all other StatelessFields. + */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer; +export const isStatefulFieldInputInstance = (val: unknown): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer; +export const isStatefulFieldOutputInstance = (val: unknown): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([zStatefulFieldOutputInstance, zStatelessFieldOutputInstance]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = (val: unknown): val is FieldOutputInstance => + zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = (val: unknown): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer; +export const isStatefulFieldOutputTemplate = (val: unknown): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = (val: unknown): val is FieldOutputTemplate => + zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts new file mode 100644 index 0000000000..86ec70fd9b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/invocation.ts @@ -0,0 +1,93 @@ +import type { Edge, Node } from 'reactflow'; +import { z } from 'zod'; + +import { zClassification, zProgressImage } from './common'; +import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputInstance, zFieldOutputTemplate } from './field'; +import { zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + version: zSemVer, + useCache: z.boolean(), + nodePack: z.string().min(1).nullish(), + classification: zClassification, +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + nodePack: z.string().min(1).nullish(), + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); +export const isNotesNode = (node?: AnyNode): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = (node?: AnyNodeData): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum(['PENDING', 'IN_PROGRESS', 'COMPLETED', 'FAILED']); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationNodeEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts new file mode 100644 index 0000000000..0cc30499e3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/metadata.ts @@ -0,0 +1,77 @@ +import { z } from 'zod'; + +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataItem = zMainModelField.deepPartial(); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer; +export type ModelMetadataItem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + cfg_rescale_multiplier: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataItem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts new file mode 100644 index 0000000000..83d774439a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/openapi.ts @@ -0,0 +1,86 @@ +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { + InputFieldJSONSchemaExtra, + InvocationJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = InvocationJSONSchemaExtra & { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = (ArraySchemaObject | NonArraySchemaObject) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationSchemaObject +): obj is InvocationSchemaObject => 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts new file mode 100644 index 0000000000..3ba330eac4 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/semver.ts @@ -0,0 +1,21 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && zVersionInt.safeParse(minor).success && zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts new file mode 100644 index 0000000000..723a354013 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/v2/workflow.ts @@ -0,0 +1,89 @@ +import { z } from 'zod'; + +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; + +export const zWorkflowCategory = z.enum(['user', 'default', 'project']); +export type WorkflowCategory = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = (val: unknown): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([zWorkflowEdgeDefault, zWorkflowEdgeCollapsed]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + id: z.string().min(1).optional(), + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + category: zWorkflowCategory.default('user'), + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 723a354013..adad7c0f21 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -24,16 +24,12 @@ export const zWorkflowInvocationNode = z.object({ id: z.string().trim().min(1), type: z.literal('invocation'), data: zInvocationNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNotesNode = z.object({ id: z.string().trim().min(1), type: z.literal('notes'), data: zNotesNodeData, - width: zDimension, - height: zDimension, position: zXYPosition, }); export const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); @@ -68,7 +64,7 @@ export type WorkflowEdge = z.infer; // #endregion // #region Workflow -export const zWorkflowV2 = z.object({ +export const zWorkflowV3 = z.object({ id: z.string().min(1).optional(), name: z.string(), author: z.string(), @@ -82,8 +78,8 @@ export const zWorkflowV2 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('2.0.0'), + version: z.literal('3.0.0'), }), }); -export type WorkflowV2 = z.infer; +export type WorkflowV3 = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index ea40bd4660..af19aa86ea 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -1,5 +1,5 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; -import type { FieldInputInstance, FieldOutputInstance } from 'features/nodes/types/field'; +import type { FieldInputInstance } from 'features/nodes/types/field'; import type { InvocationNode, InvocationTemplate } from 'features/nodes/types/invocation'; import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { reduce } from 'lodash-es'; @@ -24,25 +24,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe {} as Record ); - const outputs = reduce( - template.outputs, - (outputsAccumulator, outputTemplate, outputName) => { - const fieldId = uuidv4(); - - const outputFieldValue: FieldOutputInstance = { - id: fieldId, - name: outputName, - type: outputTemplate.type, - fieldKind: 'output', - }; - - outputsAccumulator[outputName] = outputFieldValue; - - return outputsAccumulator; - }, - {} as Record - ); - const node: InvocationNode = { ...SHARED_NODE_PROPERTIES, id: nodeId, @@ -58,7 +39,6 @@ export const buildInvocationNode = (position: XYPosition, template: InvocationTe isIntermediate: type === 'save_image' ? false : true, useCache: template.useCache, inputs, - outputs, }, }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts index f195c49d30..5ece51d0f3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts @@ -54,6 +54,5 @@ export const updateNode = (node: InvocationNode, template: InvocationTemplate): // Remove any fields that are not in the template clone.data.inputs = pick(clone.data.inputs, keys(defaults.data.inputs)); - clone.data.outputs = pick(clone.data.outputs, keys(defaults.data.outputs)); return clone; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index dd3cf0ad7b..f8097566c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -23,11 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record = export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => { const fieldInstance: FieldInputInstance = { - id, name: template.name, - type: template.type, label: '', - fieldKind: 'input' as const, value: template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name), }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 70775a9882..720da16464 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -2,8 +2,8 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import type { NodesState, WorkflowsState } from 'features/nodes/store/types'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; import { cloneDeep, pick } from 'lodash-es'; import { fromZodError } from 'zod-validation-error'; @@ -25,14 +25,14 @@ const workflowKeys = [ 'exposedFields', 'meta', 'id', -] satisfies (keyof WorkflowV2)[]; +] satisfies (keyof WorkflowV3)[]; -export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV2; +export type BuildWorkflowFunction = (arg: BuildWorkflowArg) => WorkflowV3; -export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 => { +export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 => { const clonedWorkflow = pick(cloneDeep(workflow), workflowKeys); - const newWorkflow: WorkflowV2 = { + const newWorkflow: WorkflowV3 = { ...clonedWorkflow, nodes: [], edges: [], @@ -45,8 +45,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } else if (isNotesNode(node) && node.type) { newWorkflow.nodes.push({ @@ -54,8 +52,6 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo type: node.type, data: cloneDeep(node.data), position: { ...node.position }, - width: node.width, - height: node.height, }); } }); @@ -83,12 +79,12 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo return newWorkflow; }; -export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV2 | null => { +export const buildWorkflowWithValidation = ({ nodes, edges, workflow }: BuildWorkflowArg): WorkflowV3 | null => { // builds what really, really should be a valid workflow const workflowToValidate = buildWorkflowFast({ nodes, edges, workflow }); // but bc we are storing this in the DB, let's be extra sure - const result = zWorkflowV2.safeParse(workflowToValidate); + const result = zWorkflowV3.safeParse(workflowToValidate); if (!result.success) { const { message } = fromZodError(result.error, { diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index a2677f3d17..a023c96ba9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -6,8 +6,10 @@ import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; -import { zWorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; +import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import { zWorkflowV3 } from 'features/nodes/types/workflow'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { z } from 'zod'; @@ -30,7 +32,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodeTemplates.templates; + const invocationTemplates = $store.get()?.getState().nodes.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized')); @@ -70,26 +72,34 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { return zWorkflowV2.parse(workflowToMigrate); }; +const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { + // Bump version + (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; + // Parsing strips out any extra properties not in the latest version + return zWorkflowV3.parse(workflowToMigrate); +}; + /** * Parses a workflow and migrates it to the latest version if necessary. */ -export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); if (!workflowVersionResult.success) { throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); } - const { version } = workflowVersionResult.data.meta; + let workflow = data as WorkflowV1 | WorkflowV2 | WorkflowV3; - if (version === '1.0.0') { - const v1 = zWorkflowV1.parse(data); - return migrateV1toV2(v1); + if (workflow.meta.version === '1.0.0') { + const v1 = zWorkflowV1.parse(workflow); + workflow = migrateV1toV2(v1); } - if (version === '2.0.0') { - return zWorkflowV2.parse(data); + if (workflow.meta.version === '2.0.0') { + const v2 = zWorkflowV2.parse(workflow); + workflow = migrateV2toV3(v2); } - throw new WorkflowVersionError(t('nodes.unrecognizedWorkflowVersion', { version })); + return workflow as WorkflowV3; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 848d2aee77..5096e588b0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,6 @@ import { parseify } from 'common/util/serialize'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { isWorkflowInvocationNode } from 'features/nodes/types/workflow'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { t } from 'i18next'; @@ -16,7 +16,7 @@ type WorkflowWarning = { }; type ValidateWorkflowResult = { - workflow: WorkflowV2; + workflow: WorkflowV3; warnings: WorkflowWarning[]; }; diff --git a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts index 5d484b6897..7b49d70213 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts +++ b/invokeai/frontend/web/src/features/workflowLibrary/hooks/useSaveWorkflow.ts @@ -3,7 +3,7 @@ import { useToast } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { $builtWorkflow } from 'features/nodes/hooks/useWorkflowWatcher'; import { workflowIDChanged, workflowSaved } from 'features/nodes/store/workflowSlice'; -import type { WorkflowV2 } from 'features/nodes/types/workflow'; +import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { workflowUpdated } from 'features/workflowLibrary/store/actions'; import { useCallback, useRef } from 'react'; import { useTranslation } from 'react-i18next'; @@ -18,7 +18,7 @@ type UseSaveLibraryWorkflowReturn = { type UseSaveLibraryWorkflow = () => UseSaveLibraryWorkflowReturn; -export const isWorkflowWithID = (workflow: WorkflowV2): workflow is O.Required => +export const isWorkflowWithID = (workflow: WorkflowV3): workflow is O.Required => Boolean(workflow.id); export const useSaveLibraryWorkflow: UseSaveLibraryWorkflow = () => {