From 811d0da0f0fb6cceda0eb5b148bdb2bd93852b0d Mon Sep 17 00:00:00 2001 From: Shukri Date: Sat, 18 May 2024 03:37:34 +0200 Subject: [PATCH 001/207] docs: fix link to. install reqs --- docs/installation/020_INSTALL_MANUAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 36859a5795..0d7150387c 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -10,7 +10,7 @@ InvokeAI is distributed as a python package on PyPI, installable with `pip`. The ### Requirements -Before you start, go through the [installation requirements]. +Before you start, go through the [installation requirements](./INSTALL_REQUIREMENTS.md). ### Installation Walkthrough From a5d08c981b800f565ea4cc7e9137466868338142 Mon Sep 17 00:00:00 2001 From: Shukri Date: Sat, 18 May 2024 03:53:30 +0200 Subject: [PATCH 002/207] docs: fix typo in --root arg of invokeai-web --- docs/installation/020_INSTALL_MANUAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index 0d7150387c..a3868c8fcb 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -116,4 +116,4 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME !!! warning - If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root_dir \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. + If the virtual environment is _not_ inside the root directory, then you _must_ specify the path to the root directory with `--root \path\to\invokeai` or the `INVOKEAI_ROOT` environment variable. From e8387d75239f7ee6e52aa9cf5e978d01277f1eba Mon Sep 17 00:00:00 2001 From: Shukri Date: Sat, 18 May 2024 03:55:49 +0200 Subject: [PATCH 003/207] docs: add link to tool on pytorch website --- docs/installation/020_INSTALL_MANUAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index a3868c8fcb..f589848b05 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -79,7 +79,7 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. - - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website]. Copy the `--extra-index-url` string from this and append it to your install command. + - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command. !!! example "Install with an extra index URL" From 124d34a8cc9155db06052f1f3e806088a2e7e3d1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 14:25:19 +1000 Subject: [PATCH 004/207] docs: add link for `--extra-index-url` --- docs/installation/020_INSTALL_MANUAL.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/020_INSTALL_MANUAL.md b/docs/installation/020_INSTALL_MANUAL.md index f589848b05..059834eb45 100644 --- a/docs/installation/020_INSTALL_MANUAL.md +++ b/docs/installation/020_INSTALL_MANUAL.md @@ -79,7 +79,7 @@ Before you start, go through the [installation requirements](./INSTALL_REQUIREME 1. Install the InvokeAI Package. The base command is `pip install InvokeAI --use-pep517`, but you may need to change this depending on your system and the desired features. - - You may need to provide an [extra index URL]. Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command. + - You may need to provide an [extra index URL](https://pip.pypa.io/en/stable/cli/pip_install/#cmdoption-extra-index-url). Select your platform configuration using [this tool on the PyTorch website](https://pytorch.org/get-started/locally/). Copy the `--extra-index-url` string from this and append it to your install command. !!! example "Install with an extra index URL" From 5127fd6320c44ef77f0bb8b074b5468ae4470141 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 14:16:58 +1000 Subject: [PATCH 005/207] fix(ui): control adapter autoprocess jank If you change the control model and the new model has the same default processor, we would still re-process the image, even if there was no need to do so. With this change, if the image and processor config are unchanged, we bail out. --- .../listeners/controlAdapterPreprocessor.ts | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index 2a59cc0317..ad464249df 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -16,6 +16,7 @@ import { CA_PROCESSOR_DATA } from 'features/controlLayers/util/controlAdapters'; import { isImageOutput } from 'features/nodes/types/common'; import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; +import { isEqual } from 'lodash-es'; import { getImageDTO } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; import type { BatchConfig } from 'services/api/types'; @@ -47,8 +48,10 @@ const cancelProcessorBatch = async (dispatch: AppDispatch, layerId: string, batc export const addControlAdapterPreprocessor = (startAppListening: AppStartListening) => { startAppListening({ matcher, - effect: async (action, { dispatch, getState, cancelActiveListeners, delay, take, signal }) => { + effect: async (action, { dispatch, getState, getOriginalState, cancelActiveListeners, delay, take, signal }) => { const layerId = caLayerRecalled.match(action) ? action.payload.id : action.payload.layerId; + const state = getState(); + const originalState = getOriginalState(); // Cancel any in-progress instances of this listener cancelActiveListeners(); @@ -57,18 +60,27 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni // Delay before starting actual work await delay(DEBOUNCE_MS); - // Double-check that we are still eligible for processing - const state = getState(); const layer = state.controlLayers.present.layers.filter(isControlAdapterLayer).find((l) => l.id === layerId); - // If we have no image or there is no processor config, bail if (!layer) { return; } + // We should only process if the processor settings or image have changed + const originalLayer = originalState.controlLayers.present.layers + .filter(isControlAdapterLayer) + .find((l) => l.id === layerId); + const originalImage = originalLayer?.controlAdapter.image; + const originalConfig = originalLayer?.controlAdapter.processorConfig; + const image = layer.controlAdapter.image; const config = layer.controlAdapter.processorConfig; + if (isEqual(config, originalConfig) && isEqual(image, originalImage)) { + // Neither config nor image have changed, we can bail + return; + } + if (!image || !config) { // The user has reset the image or config, so we should clear the processed image dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); From af3fd26d4e74b4aa005ab363c0badd13bd30a616 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 14:19:54 +1000 Subject: [PATCH 006/207] fix(ui): bug when clearing processor When clearing the processor config, we shouldn't re-process the image. This logic wasn't handled correctly, but coincidentally the bug didn't cause a user-facing issue. Without a config, we had a runtime error when trying to build the node for the processor graph and the listener failed. So while we didn't re-process the image, it was because there was an error, not because the logic was correct. Fix this by bailing if there is no image or config. --- .../listeners/controlAdapterPreprocessor.ts | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts index ad464249df..3dc8db93f9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlAdapterPreprocessor.ts @@ -82,8 +82,11 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni } if (!image || !config) { - // The user has reset the image or config, so we should clear the processed image + // - If we have no image, we have nothing to process + // - If we have no processor config, we have nothing to process + // Clear the processed image and bail dispatch(caLayerProcessedImageChanged({ layerId, imageDTO: null })); + return; } // At this point, the user has stopped fiddling with the processor settings and there is a processor selected. @@ -93,8 +96,8 @@ export const addControlAdapterPreprocessor = (startAppListening: AppStartListeni cancelProcessorBatch(dispatch, layerId, layer.controlAdapter.processorPendingBatchId); } - // @ts-expect-error: TS isn't able to narrow the typing of buildNode and `config` will error... - const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config); + // TODO(psyche): I can't get TS to be happy, it thinkgs `config` is `never` but it should be inferred from the generic... I'll just cast it for now + const processorNode = CA_PROCESSOR_DATA[config.type].buildNode(image, config as never); const enqueueBatchArg: BatchConfig = { prepend: true, batch: { From 85a5a7c47a76a75ba3244bf449084886b03c414e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 20:08:32 +1000 Subject: [PATCH 007/207] feat(ui): add `originalType` to FieldType, improved connection validation We now keep track of the original field type, derived from the python type annotation in addition to the override type provided by `ui_type`. This makes `ui_type` work more like it sound like it should work - change the UI input component only. Connection validation is extend to also check the original types. If there is any match between two fields' "final" or original types, we consider the connection valid.This change is backwards-compatible; there is no workflow migration needed. --- .../nodes/hooks/useIsValidConnection.ts | 5 +- .../store/util/findConnectionToValidHandle.ts | 6 +- .../util/makeIsConnectionValidSelector.ts | 5 +- .../util/validateSourceAndTargetTypes.ts | 26 +- .../web/src/features/nodes/types/field.ts | 231 ++++++++++++------ .../util/schema/buildFieldInputTemplate.ts | 188 ++++---------- .../util/schema/buildFieldOutputTemplate.ts | 4 +- .../nodes/util/schema/parseFieldType.test.ts | 6 +- .../nodes/util/schema/parseFieldType.ts | 25 +- .../nodes/util/schema/parseSchema.test.ts | 152 ++++++++++++ .../features/nodes/util/schema/parseSchema.ts | 200 ++++++++------- 11 files changed, 502 insertions(+), 346 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 00b4b40176..14a7a728e0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -4,9 +4,8 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { areTypesEqual, validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; -import { isEqual } from 'lodash-es'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; @@ -70,7 +69,7 @@ export const useIsValidConnection = () => { // Collect nodes shouldn't mix and match field types const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); if (collectItemType) { - return isEqual(sourceFieldTemplate.type, collectItemType); + return areTypesEqual(sourceFieldTemplate.type, collectItemType); } } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 1f33c52371..e0411ee67e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,12 +1,12 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types'; import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { differenceWith, isEqual, map } from 'lodash-es'; +import { differenceWith, map } from 'lodash-es'; import type { Connection } from 'reactflow'; import { assert } from 'tsafe'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; +import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; export const getFirstValidConnection = ( templates: Templates, @@ -83,7 +83,7 @@ export const getFirstValidConnection = ( // Narrow candidates to same field type as already is connected to the collect node const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); if (collectItemType) { - candidateFields = candidateFields.filter((field) => isEqual(field.type, collectItemType)); + candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); } } const candidateField = candidateFields.find((field) => { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 90e75e0d87..e7f659508f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -4,12 +4,11 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types'; import type { FieldType } from 'features/nodes/types/field'; import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import i18n from 'i18next'; -import { isEqual } from 'lodash-es'; import type { HandleType } from 'reactflow'; import { assert } from 'tsafe'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; +import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; export const getCollectItemType = ( templates: Templates, @@ -111,7 +110,7 @@ export const makeConnectionErrorSelector = ( // Collect nodes shouldn't mix and match field types const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); if (collectItemType) { - if (!isEqual(sourceType, collectItemType)) { + if (!areTypesEqual(sourceType, collectItemType)) { return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); } } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 3cbfb5b89c..cc5a6bb596 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -1,5 +1,25 @@ -import type { FieldType } from 'features/nodes/types/field'; -import { isEqual } from 'lodash-es'; +import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; +import { isEqual, omit } from 'lodash-es'; + +export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { + const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType; + const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType; + const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType; + const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType; + if (isEqual(_sourceType, _targetType)) { + return true; + } + if (isEqual(_sourceType, _targetTypeOriginal)) { + return true; + } + if (isEqual(_sourceTypeOriginal, _targetType)) { + return true; + } + if (isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + return true; + } + return false; +}; /** * Validates that the source and target types are compatible for a connection. @@ -15,7 +35,7 @@ export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: return false; } - if (isEqual(sourceType, targetType)) { + if (areTypesEqual(sourceType, targetType)) { return true; } diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 87b0839bc3..37e2a26397 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -66,16 +66,114 @@ export const zFieldIdentifier = z.object({ export type FieldIdentifier = z.infer; // #endregion -// #region IntegerField +// #region Field Types +const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); const zIntegerFieldType = zFieldTypeBase.extend({ name: z.literal('IntegerField'), + originalType: zStatelessFieldType.optional(), }); +const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), + originalType: zStatelessFieldType.optional(), +}); +const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), + originalType: zStatelessFieldType.optional(), +}); +const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), + originalType: zStatelessFieldType.optional(), +}); +const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), + originalType: zStatelessFieldType.optional(), +}); +const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), + originalType: zStatelessFieldType.optional(), +}); +const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), + originalType: zStatelessFieldType.optional(), +}); +const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), + originalType: zStatelessFieldType.optional(), +}); +const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), + originalType: zStatelessFieldType.optional(), +}); +const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), + originalType: zStatelessFieldType.optional(), +}); +const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); +export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType => + statefulFieldTypeNames.includes(fieldType.name as any); +const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +// #endregion + +// #region IntegerField + export const zIntegerFieldValue = z.number().int(); const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ value: zIntegerFieldValue, }); const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIntegerFieldType, + originalType: zFieldType.optional(), default: zIntegerFieldValue, multipleOf: z.number().int().optional(), maximum: z.number().int().optional(), @@ -85,6 +183,7 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIntegerFieldType, + originalType: zFieldType.optional(), }); export type IntegerFieldValue = z.infer; export type IntegerFieldInputInstance = z.infer; @@ -96,15 +195,14 @@ export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldIn // #endregion // #region FloatField -const zFloatFieldType = zFieldTypeBase.extend({ - name: z.literal('FloatField'), -}); + export const zFloatFieldValue = z.number(); const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ value: zFloatFieldValue, }); const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zFloatFieldType, + originalType: zFieldType.optional(), default: zFloatFieldValue, multipleOf: z.number().optional(), maximum: z.number().optional(), @@ -114,6 +212,7 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zFloatFieldType, + originalType: zFieldType.optional(), }); export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; @@ -125,21 +224,21 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT // #endregion // #region StringField -const zStringFieldType = zFieldTypeBase.extend({ - name: z.literal('StringField'), -}); + export const zStringFieldValue = z.string(); const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ value: zStringFieldValue, }); const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStringFieldType, + originalType: zFieldType.optional(), default: zStringFieldValue, maxLength: z.number().int().optional(), minLength: z.number().int().optional(), }); const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStringFieldType, + originalType: zFieldType.optional(), }); export type StringFieldValue = z.infer; @@ -152,19 +251,19 @@ export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInpu // #endregion // #region BooleanField -const zBooleanFieldType = zFieldTypeBase.extend({ - name: z.literal('BooleanField'), -}); + export const zBooleanFieldValue = z.boolean(); const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ value: zBooleanFieldValue, }); const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBooleanFieldType, + originalType: zFieldType.optional(), default: zBooleanFieldValue, }); const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBooleanFieldType, + originalType: zFieldType.optional(), }); export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; @@ -176,21 +275,21 @@ export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldIn // #endregion // #region EnumField -const zEnumFieldType = zFieldTypeBase.extend({ - name: z.literal('EnumField'), -}); + export const zEnumFieldValue = z.string(); const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ value: zEnumFieldValue, }); const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zEnumFieldType, + originalType: zFieldType.optional(), default: zEnumFieldValue, options: z.array(z.string()), labels: z.record(z.string()).optional(), }); const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zEnumFieldType, + originalType: zFieldType.optional(), }); export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; @@ -202,19 +301,19 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem // #endregion // #region ImageField -const zImageFieldType = zFieldTypeBase.extend({ - name: z.literal('ImageField'), -}); + export const zImageFieldValue = zImageField.optional(); const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ value: zImageFieldValue, }); const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zImageFieldType, + originalType: zFieldType.optional(), default: zImageFieldValue, }); const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zImageFieldType, + originalType: zFieldType.optional(), }); export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; @@ -226,19 +325,19 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT // #endregion // #region BoardField -const zBoardFieldType = zFieldTypeBase.extend({ - name: z.literal('BoardField'), -}); + export const zBoardFieldValue = zBoardField.optional(); const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ value: zBoardFieldValue, }); const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zBoardFieldType, + originalType: zFieldType.optional(), default: zBoardFieldValue, }); const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBoardFieldType, + originalType: zFieldType.optional(), }); export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; @@ -250,19 +349,19 @@ export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputT // #endregion // #region ColorField -const zColorFieldType = zFieldTypeBase.extend({ - name: z.literal('ColorField'), -}); + export const zColorFieldValue = zColorField.optional(); const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ value: zColorFieldValue, }); const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zColorFieldType, + originalType: zFieldType.optional(), default: zColorFieldValue, }); const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zColorFieldType, + originalType: zFieldType.optional(), }); export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; @@ -274,19 +373,19 @@ export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputT // #endregion // #region MainModelField -const zMainModelFieldType = zFieldTypeBase.extend({ - name: z.literal('MainModelField'), -}); + export const zMainModelFieldValue = zModelIdentifierField.optional(); const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zMainModelFieldValue, }); const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zMainModelFieldType, + originalType: zFieldType.optional(), default: zMainModelFieldValue, }); const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zMainModelFieldType, + originalType: zFieldType.optional(), }); export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; @@ -298,19 +397,19 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie // #endregion // #region SDXLMainModelField -const zSDXLMainModelFieldType = zFieldTypeBase.extend({ - name: z.literal('SDXLMainModelField'), -}); + const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zSDXLMainModelFieldValue, }); const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLMainModelFieldType, + originalType: zFieldType.optional(), default: zSDXLMainModelFieldValue, }); const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLMainModelFieldType, + originalType: zFieldType.optional(), }); export type SDXLMainModelFieldInputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; @@ -321,9 +420,7 @@ export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMain // #endregion // #region SDXLRefinerModelField -const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ - name: z.literal('SDXLRefinerModelField'), -}); + /** @alias */ // tells knip to ignore this duplicate export export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ @@ -331,10 +428,12 @@ const zSDXLRefinerModelFieldInputInstance = zFieldInputInstanceBase.extend({ }); const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, + originalType: zFieldType.optional(), default: zSDXLRefinerModelFieldValue, }); const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, + originalType: zFieldType.optional(), }); export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; @@ -346,19 +445,19 @@ export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLR // #endregion // #region VAEModelField -const zVAEModelFieldType = zFieldTypeBase.extend({ - name: z.literal('VAEModelField'), -}); + export const zVAEModelFieldValue = zModelIdentifierField.optional(); const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zVAEModelFieldValue, }); const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zVAEModelFieldType, + originalType: zFieldType.optional(), default: zVAEModelFieldValue, }); const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zVAEModelFieldType, + originalType: zFieldType.optional(), }); export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; @@ -370,19 +469,19 @@ export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelField // #endregion // #region LoRAModelField -const zLoRAModelFieldType = zFieldTypeBase.extend({ - name: z.literal('LoRAModelField'), -}); + export const zLoRAModelFieldValue = zModelIdentifierField.optional(); const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zLoRAModelFieldValue, }); const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zLoRAModelFieldType, + originalType: zFieldType.optional(), default: zLoRAModelFieldValue, }); const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zLoRAModelFieldType, + originalType: zFieldType.optional(), }); export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; @@ -394,19 +493,19 @@ export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFie // #endregion // #region ControlNetModelField -const zControlNetModelFieldType = zFieldTypeBase.extend({ - name: z.literal('ControlNetModelField'), -}); + export const zControlNetModelFieldValue = zModelIdentifierField.optional(); const zControlNetModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zControlNetModelFieldValue, }); const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zControlNetModelFieldType, + originalType: zFieldType.optional(), default: zControlNetModelFieldValue, }); const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zControlNetModelFieldType, + originalType: zFieldType.optional(), }); export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; @@ -418,19 +517,19 @@ export const isControlNetModelFieldInputTemplate = (val: unknown): val is Contro // #endregion // #region IPAdapterModelField -const zIPAdapterModelFieldType = zFieldTypeBase.extend({ - name: z.literal('IPAdapterModelField'), -}); + export const zIPAdapterModelFieldValue = zModelIdentifierField.optional(); const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zIPAdapterModelFieldValue, }); const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zIPAdapterModelFieldType, + originalType: zFieldType.optional(), default: zIPAdapterModelFieldValue, }); const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIPAdapterModelFieldType, + originalType: zFieldType.optional(), }); export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; @@ -442,19 +541,19 @@ export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapt // #endregion // #region T2IAdapterField -const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ - name: z.literal('T2IAdapterModelField'), -}); + export const zT2IAdapterModelFieldValue = zModelIdentifierField.optional(); const zT2IAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend({ value: zT2IAdapterModelFieldValue, }); const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zT2IAdapterModelFieldType, + originalType: zFieldType.optional(), default: zT2IAdapterModelFieldValue, }); const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zT2IAdapterModelFieldType, + originalType: zFieldType.optional(), }); export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; @@ -466,19 +565,19 @@ export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAda // #endregion // #region SchedulerField -const zSchedulerFieldType = zFieldTypeBase.extend({ - name: z.literal('SchedulerField'), -}); + export const zSchedulerFieldValue = zSchedulerField.optional(); const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ value: zSchedulerFieldValue, }); const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zSchedulerFieldType, + originalType: zFieldType.optional(), default: zSchedulerFieldValue, }); const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSchedulerFieldType, + originalType: zFieldType.optional(), }); export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; @@ -501,20 +600,20 @@ export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFie * - Reserved fields like IsIntermediate * - Any other field we don't have full-on schemas for */ -const zStatelessFieldType = zFieldTypeBase.extend({ - name: z.string().min(1), // stateless --> we accept the field's name as the type -}); + const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ value: zStatelessFieldValue, }); const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ type: zStatelessFieldType, + originalType: zFieldType.optional(), default: zStatelessFieldValue, input: z.literal('connection'), // stateless --> only accepts connection inputs }); const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStatelessFieldType, + originalType: zFieldType.optional(), }); export type StatelessFieldInputTemplate = z.infer; @@ -535,34 +634,6 @@ export type StatelessFieldInputTemplate = z.infer; -export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => - zStatefulFieldType.safeParse(val).success; - -const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); -export type FieldType = z.infer; -// #endregion - // #region StatefulFieldValue & FieldValue export const zStatefulFieldValue = z.union([ zIntegerFieldValue, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 3e8278ea6a..6b4c4d8b29 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -30,26 +30,16 @@ import { isNumber, startCase } from 'lodash-es'; // eslint-disable-next-line @typescript-eslint/no-explicit-any type FieldInputTemplateBuilder = // valid `any`! - (arg: { - schemaObject: InvocationFieldSchema; - baseField: Omit; - isCollection: boolean; - isCollectionOrScalar: boolean; - }) => T; + (arg: { schemaObject: InvocationFieldSchema; baseField: Omit; fieldType: T['type'] }) => T; const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: IntegerFieldInputTemplate = { ...baseField, - type: { - name: 'IntegerField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 0, }; @@ -79,16 +69,11 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: FloatFieldInputTemplate = { ...baseField, - type: { - name: 'FloatField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 0, }; @@ -118,16 +103,11 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: StringFieldInputTemplate = { ...baseField, - type: { - name: 'StringField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? '', }; @@ -145,16 +125,11 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: BooleanFieldInputTemplate = { ...baseField, - type: { - name: 'BooleanField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? false, }; @@ -164,16 +139,11 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: MainModelFieldInputTemplate = { ...baseField, - type: { - name: 'MainModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -183,16 +153,11 @@ const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SDXLMainModelFieldInputTemplate = { ...baseField, - type: { - name: 'SDXLMainModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -202,16 +167,11 @@ const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SDXLRefinerModelFieldInputTemplate = { ...baseField, - type: { - name: 'SDXLRefinerModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -221,16 +181,11 @@ const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: VAEModelFieldInputTemplate = { ...baseField, - type: { - name: 'VAEModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -240,16 +195,11 @@ const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: LoRAModelFieldInputTemplate = { ...baseField, - type: { - name: 'LoRAModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -259,16 +209,11 @@ const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ControlNetModelFieldInputTemplate = { ...baseField, - type: { - name: 'ControlNetModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -278,16 +223,11 @@ const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: IPAdapterModelFieldInputTemplate = { ...baseField, - type: { - name: 'IPAdapterModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -297,16 +237,11 @@ const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: T2IAdapterModelFieldInputTemplate = { ...baseField, - type: { - name: 'T2IAdapterModelField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -316,16 +251,11 @@ const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: BoardFieldInputTemplate = { ...baseField, - type: { - name: 'BoardField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -335,16 +265,11 @@ const buildBoardFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ImageFieldInputTemplate = { ...baseField, - type: { - name: 'ImageField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? undefined, }; @@ -354,8 +279,7 @@ const buildImageFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { let options: EnumFieldInputTemplate['options'] = []; if (schemaObject.anyOf) { @@ -383,11 +307,7 @@ const buildEnumFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: ColorFieldInputTemplate = { ...baseField, - type: { - name: 'ColorField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, }; @@ -418,16 +333,11 @@ const buildColorFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, - isCollection, - isCollectionOrScalar, + fieldType, }) => { const template: SchedulerFieldInputTemplate = { ...baseField, - type: { - name: 'SchedulerField', - isCollection, - isCollectionOrScalar, - }, + type: fieldType, default: schemaObject.default ?? 'euler', }; @@ -452,7 +362,7 @@ export const TEMPLATE_BUILDER_MAP: Record connection only inputs - type: fieldType, - default: undefined, // stateless --> no default value - }; - return template; + return template; + } else { + // This is a StatelessField, create it directly. + const template: StatelessFieldInputTemplate = { + ...baseField, + input: 'connection', // stateless --> connection only inputs + type: fieldType, + default: undefined, // stateless --> no default value + }; + + return template; + } }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts index 8c789493ad..abbe2c3488 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts @@ -9,7 +9,7 @@ export const buildFieldOutputTemplate = ( ): FieldOutputTemplate => { const { title, description, ui_hidden, ui_type, ui_order } = fieldSchema; - const fieldOutputTemplate: FieldOutputTemplate = { + const template: FieldOutputTemplate = { fieldKind: 'output', name: fieldName, title: title ?? (fieldName ? startCase(fieldName) : ''), @@ -20,5 +20,5 @@ export const buildFieldOutputTemplate = ( ui_order, }; - return fieldOutputTemplate; + return template; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts index d7011ad6f8..cc12b45aa6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -244,7 +244,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'SchedulerField', }, - expected: { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, }, { name: 'Explicit ui_type (AnyField)', @@ -253,7 +253,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'AnyField', }, - expected: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, }, { name: 'Explicit ui_type (CollectionField)', @@ -262,7 +262,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'CollectionField', }, - expected: { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, }, ]; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 13da6b3831..6f6ecaa5bb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -6,14 +6,8 @@ import { UnsupportedUnionError, } from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; -import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; -import { - isArraySchemaObject, - isInvocationFieldSchema, - isNonArraySchemaObject, - isRefObject, - isSchemaObject, -} from 'features/nodes/types/openapi'; +import type { OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; +import { isArraySchemaObject, isNonArraySchemaObject, isRefObject, isSchemaObject } from 'features/nodes/types/openapi'; import { t } from 'i18next'; import { isArray } from 'lodash-es'; import type { OpenAPIV3_1 } from 'openapi-types'; @@ -35,7 +29,7 @@ const OPENAPI_TO_FIELD_TYPE_MAP: Record = { boolean: 'BooleanField', }; -const isCollectionFieldType = (fieldType: string) => { +export const isCollectionFieldType = (fieldType: string) => { /** * CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between * it and other `list[Any]` fields, due to its special internal handling. @@ -48,18 +42,7 @@ const isCollectionFieldType = (fieldType: string) => { return false; }; -export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema): FieldType => { - if (isInvocationFieldSchema(schemaObject)) { - // Check if this field has an explicit type provided by the node schema - const { ui_type } = schemaObject; - if (ui_type) { - return { - name: ui_type, - isCollection: isCollectionFieldType(ui_type), - isCollectionOrScalar: false, - }; - } - } +export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType => { if (isSchemaObject(schemaObject)) { if (schemaObject.const) { // Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts index 6c0a6635c7..480387a8a4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts @@ -97,6 +97,11 @@ const expected = { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, }, default: 'euler', }, @@ -111,6 +116,11 @@ const expected = { name: 'SchedulerField', isCollection: false, isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, }, ui_hidden: false, ui_type: 'SchedulerField', @@ -141,6 +151,11 @@ const expected = { name: 'MainModelField', isCollection: false, isCollectionOrScalar: false, + originalType: { + name: 'ModelIdentifierField', + isCollection: false, + isCollectionOrScalar: false, + }, }, }, }, @@ -186,6 +201,48 @@ const expected = { nodePack: 'invokeai', classification: 'stable', }, + collect: { + title: 'Collect', + type: 'collect', + version: '1.0.0', + tags: [], + description: 'Collects values into a collection', + outputType: 'collect_output', + inputs: { + item: { + name: 'item', + title: 'Collection Item', + required: false, + description: 'The item to collect (all inputs must be of the same type)', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionItemField', + type: { + name: 'CollectionItemField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + }, + outputs: { + collection: { + fieldKind: 'output', + name: 'collection', + title: 'Collection', + description: 'The collection of input items', + type: { + name: 'CollectionField', + isCollection: true, + isCollectionOrScalar: false, + }, + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + useCache: true, + classification: 'stable', + }, }; const schema = { @@ -785,6 +842,101 @@ const schema = { type: 'object', class: 'output', }, + CollectInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + item: { + anyOf: [ + {}, + { + type: 'null', + }, + ], + title: 'Collection Item', + description: 'The item to collect (all inputs must be of the same type)', + field_kind: 'input', + input: 'connection', + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The collection, will be provided on execution', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['collect'], + const: 'collect', + title: 'type', + default: 'collect', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'CollectInvocation', + description: 'Collects values into a collection', + classification: 'stable', + version: '1.0.0', + output: { + $ref: '#/components/schemas/CollectInvocationOutput', + }, + class: 'invocation', + }, + CollectInvocationOutput: { + properties: { + collection: { + description: 'The collection of input items', + field_kind: 'output', + items: {}, + title: 'Collection', + type: 'array', + ui_hidden: false, + ui_type: 'CollectionField', + }, + type: { + const: 'collect_output', + default: 'collect_output', + enum: ['collect_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['collection', 'type', 'type'], + title: 'CollectInvocationOutput', + type: 'object', + class: 'output', + }, }, }, } as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 3178209f93..0638a52954 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -1,23 +1,29 @@ import { logger } from 'app/logging/logger'; +import { deepClone } from 'common/util/deepClone'; import { parseify } from 'common/util/serialize'; import type { Templates } from 'features/nodes/store/types'; import { FieldParseError } from 'features/nodes/types/error'; -import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import { + type FieldInputTemplate, + type FieldOutputTemplate, + type FieldType, + isStatefulFieldType, +} from 'features/nodes/types/field'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; -import type { InvocationSchemaObject } from 'features/nodes/types/openapi'; +import type { InvocationFieldSchema, InvocationSchemaObject } from 'features/nodes/types/openapi'; import { isInvocationFieldSchema, isInvocationOutputSchemaObject, isInvocationSchemaObject, } from 'features/nodes/types/openapi'; import { t } from 'i18next'; -import { reduce } from 'lodash-es'; +import { isEqual, reduce } from 'lodash-es'; import type { OpenAPIV3_1 } from 'openapi-types'; import { serializeError } from 'serialize-error'; import { buildFieldInputTemplate } from './buildFieldInputTemplate'; import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; -import { parseFieldType } from './parseFieldType'; +import { isCollectionFieldType, parseFieldType } from './parseFieldType'; const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache']; const RESERVED_OUTPUT_FIELD_NAMES = ['type']; @@ -94,51 +100,43 @@ export const parseSchema = ( return inputsAccumulator; } - try { - const fieldType = parseFieldType(property); + const fieldTypeOverride = property.ui_type + ? { + name: property.ui_type, + isCollection: isCollectionFieldType(property.ui_type), + isCollectionOrScalar: false, + } + : null; - if (isReservedFieldType(fieldType.name)) { - logger('nodes').trace( - { node: type, field: propertyName, schema: parseify(property) }, - 'Skipped reserved input field' - ); - return inputsAccumulator; - } + const originalFieldType = getFieldType(property, propertyName, type, 'input'); - const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType); - - inputsAccumulator[propertyName] = fieldInputTemplate; - } catch (e) { - if (e instanceof FieldParseError) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - t('nodes.inputFieldTypeParseError', { - node: type, - field: propertyName, - message: e.message, - }) - ); - } else { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - error: serializeError(e), - }, - t('nodes.inputFieldTypeParseError', { - node: type, - field: propertyName, - message: 'unknown error', - }) - ); - } + const fieldType = fieldTypeOverride ?? originalFieldType; + if (!fieldType) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Unable to parse field type' + ); + return inputsAccumulator; } + if (isReservedFieldType(fieldType.name)) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Skipped reserved input field' + ); + return inputsAccumulator; + } + + if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { + console.log('STATEFUL WITH ORIGINAL'); + fieldType.originalType = deepClone(originalFieldType); + console.log(fieldType); + } + + const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType); + console.log(fieldInputTemplate); + inputsAccumulator[propertyName] = fieldInputTemplate; + return inputsAccumulator; }, {} @@ -183,54 +181,34 @@ export const parseSchema = ( return outputsAccumulator; } - try { - const fieldType = parseFieldType(property); + const fieldTypeOverride = property.ui_type + ? { + name: property.ui_type, + isCollection: isCollectionFieldType(property.ui_type), + isCollectionOrScalar: false, + } + : null; - if (!fieldType) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - 'Missing output field type' - ); - return outputsAccumulator; - } + const originalFieldType = getFieldType(property, propertyName, type, 'output'); - const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType); - - outputsAccumulator[propertyName] = fieldOutputTemplate; - } catch (e) { - if (e instanceof FieldParseError) { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - }, - t('nodes.outputFieldTypeParseError', { - node: type, - field: propertyName, - message: e.message, - }) - ); - } else { - logger('nodes').warn( - { - node: type, - field: propertyName, - schema: parseify(property), - error: serializeError(e), - }, - t('nodes.outputFieldTypeParseError', { - node: type, - field: propertyName, - message: 'unknown error', - }) - ); - } + const fieldType = fieldTypeOverride ?? originalFieldType; + if (!fieldType) { + logger('nodes').trace( + { node: type, field: propertyName, schema: parseify(property) }, + 'Unable to parse field type' + ); + return outputsAccumulator; } + + if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { + console.log('STATEFUL WITH ORIGINAL'); + fieldType.originalType = deepClone(originalFieldType); + console.log(fieldType); + } + + const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType); + + outputsAccumulator[propertyName] = fieldOutputTemplate; return outputsAccumulator; }, {} as Record @@ -259,3 +237,45 @@ export const parseSchema = ( return invocations; }; + +const getFieldType = ( + property: InvocationFieldSchema, + propertyName: string, + type: string, + kind: 'input' | 'output' +): FieldType | null => { + try { + return parseFieldType(property); + } catch (e) { + const tKey = kind === 'input' ? 'nodes.inputFieldTypeParseError' : 'nodes.outputFieldTypeParseError'; + if (e instanceof FieldParseError) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t(tKey, { + node: type, + field: propertyName, + message: e.message, + }) + ); + } else { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + error: serializeError(e), + }, + t(tKey, { + node: type, + field: propertyName, + message: 'unknown error', + }) + ); + } + return null; + } +}; From fe7ed72c9c12334c46f0ada4bf449b7beca04ce0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 20:15:04 +1000 Subject: [PATCH 008/207] feat(nodes): make all `ModelIdentifierField` inputs accept connections --- .../controlnet_image_processors.py | 5 ++--- invokeai/app/invocations/ip_adapter.py | 5 ++--- invokeai/app/invocations/model.py | 22 +++++++++---------- invokeai/app/invocations/sdxl.py | 10 ++++----- invokeai/app/invocations/t2i_adapter.py | 5 ++--- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index d2f01622b2..f5edd49874 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -24,7 +24,6 @@ from pydantic import BaseModel, Field, field_validator, model_validator from invokeai.app.invocations.fields import ( FieldDescriptions, ImageField, - Input, InputField, OutputField, UIType, @@ -80,13 +79,13 @@ class ControlOutput(BaseInvocationOutput): control: ControlField = OutputField(description=FieldDescriptions.control) -@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.1") +@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.1.2") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" image: ImageField = InputField(description="The control image") control_model: ModelIdentifierField = InputField( - description=FieldDescriptions.controlnet_model, input=Input.Direct, ui_type=UIType.ControlNetModel + description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel ) control_weight: Union[float, List[float]] = InputField( default=1.0, ge=-1, le=2, description="The weight given to the ControlNet" diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 34a30628da..de40879eef 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing_extensions import Self from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, TensorField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, TensorField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights @@ -58,7 +58,7 @@ class IPAdapterOutput(BaseInvocationOutput): CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"} -@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.0") +@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.4.1") class IPAdapterInvocation(BaseInvocation): """Collects IP-Adapter info to pass to other nodes.""" @@ -67,7 +67,6 @@ class IPAdapterInvocation(BaseInvocation): ip_adapter_model: ModelIdentifierField = InputField( description="The IP-Adapter model.", title="IP-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.IPAdapterModel, ) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 245034c481..05f451b957 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -98,14 +98,12 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): title="Main Model", tags=["model"], category="model", - version="1.0.2", + version="1.0.3", ) class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" - model: ModelIdentifierField = InputField( - description=FieldDescriptions.main_model, input=Input.Direct, ui_type=UIType.MainModel - ) + model: ModelIdentifierField = InputField(description=FieldDescriptions.main_model, ui_type=UIType.MainModel) # TODO: precision? def invoke(self, context: InvocationContext) -> ModelLoaderOutput: @@ -134,12 +132,12 @@ class LoRALoaderOutput(BaseInvocationOutput): clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") -@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.2") +@invocation("lora_loader", title="LoRA", tags=["model"], category="model", version="1.0.3") class LoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -197,12 +195,12 @@ class LoRASelectorOutput(BaseInvocationOutput): lora: LoRAField = OutputField(description="LoRA model and weight", title="LoRA") -@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.0") +@invocation("lora_selector", title="LoRA Selector", tags=["model"], category="model", version="1.0.1") class LoRASelectorInvocation(BaseInvocation): """Selects a LoRA model and weight.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) @@ -273,13 +271,13 @@ class SDXLLoRALoaderOutput(BaseInvocationOutput): title="SDXL LoRA", tags=["lora", "model"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLLoRALoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" lora: ModelIdentifierField = InputField( - description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA", ui_type=UIType.LoRAModel + description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) unet: Optional[UNetField] = InputField( @@ -414,12 +412,12 @@ class SDXLLoRACollectionLoader(BaseInvocation): return output -@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.2") +@invocation("vae_loader", title="VAE", tags=["vae", "model"], category="model", version="1.0.3") class VAELoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" vae_model: ModelIdentifierField = InputField( - description=FieldDescriptions.vae_model, input=Input.Direct, title="VAE", ui_type=UIType.VAEModel + description=FieldDescriptions.vae_model, title="VAE", ui_type=UIType.VAEModel ) def invoke(self, context: InvocationContext) -> VAEOutput: diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index 9b1ee90350..1c0817cb92 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,4 +1,4 @@ -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, InputField, OutputField, UIType from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager import SubModelType @@ -30,12 +30,12 @@ class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") -@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.2") +@invocation("sdxl_model_loader", title="SDXL Main Model", tags=["model", "sdxl"], category="model", version="1.0.3") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type=UIType.SDXLMainModel + description=FieldDescriptions.sdxl_main_model, ui_type=UIType.SDXLMainModel ) # TODO: precision? @@ -67,13 +67,13 @@ class SDXLModelLoaderInvocation(BaseInvocation): title="SDXL Refiner Model", tags=["model", "sdxl", "refiner"], category="model", - version="1.0.2", + version="1.0.3", ) class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" model: ModelIdentifierField = InputField( - description=FieldDescriptions.sdxl_refiner_model, input=Input.Direct, ui_type=UIType.SDXLRefinerModel + description=FieldDescriptions.sdxl_refiner_model, ui_type=UIType.SDXLRefinerModel ) # TODO: precision? diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index b22a089d3f..04f9a6c695 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation, invocation_output, ) -from invokeai.app.invocations.fields import FieldDescriptions, ImageField, Input, InputField, OutputField, UIType +from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.services.shared.invocation_context import InvocationContext @@ -45,7 +45,7 @@ class T2IAdapterOutput(BaseInvocationOutput): @invocation( - "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.2" + "t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.3" ) class T2IAdapterInvocation(BaseInvocation): """Collects T2I-Adapter info to pass to other nodes.""" @@ -55,7 +55,6 @@ class T2IAdapterInvocation(BaseInvocation): t2i_adapter_model: ModelIdentifierField = InputField( description="The T2I-Adapter model.", title="T2I-Adapter Model", - input=Input.Direct, ui_order=-1, ui_type=UIType.T2IAdapterModel, ) From 2cbf7d9221bdce7e7dee5072000e1cf601236318 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 20:18:20 +1000 Subject: [PATCH 009/207] fix(ui): stupid ts --- invokeai/frontend/web/src/features/nodes/types/field.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 37e2a26397..4dcc478352 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -160,7 +160,7 @@ const zStatefulFieldType = z.union([ export type StatefulFieldType = z.infer; const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value); export const isStatefulFieldType = (fieldType: FieldType): fieldType is StatefulFieldType => - statefulFieldTypeNames.includes(fieldType.name as any); + (statefulFieldTypeNames as string[]).includes(fieldType.name); const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); export type FieldType = z.infer; // #endregion From 6a2c53f6c5e9ad8d7c742b5e7e762ccc48e221c8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 20:33:44 +1000 Subject: [PATCH 010/207] fix(ui): do not allow comparison between undefined original types --- .../nodes/store/util/validateSourceAndTargetTypes.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index cc5a6bb596..45b771b5b4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -9,13 +9,13 @@ export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { if (isEqual(_sourceType, _targetType)) { return true; } - if (isEqual(_sourceType, _targetTypeOriginal)) { + if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { return true; } - if (isEqual(_sourceTypeOriginal, _targetType)) { + if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { return true; } - if (isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { return true; } return false; From a012bb6e071ab3dbd1d023d45068f698155531ff Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 20:47:00 +1000 Subject: [PATCH 011/207] feat(ui): add ModelIdentifierField field type This new field type accepts _any_ model. A field renderer lets the user select any available model. --- .../Invocation/fields/InputFieldRenderer.tsx | 7 ++ .../ModelIdentifierFieldInputComponent.tsx | 68 +++++++++++++++++++ .../src/features/nodes/store/nodesSlice.ts | 6 ++ .../web/src/features/nodes/types/field.ts | 32 +++++++++ .../util/schema/buildFieldInputInstance.ts | 1 + .../util/schema/buildFieldInputTemplate.ts | 16 +++++ 6 files changed, 130 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index b6e331c114..99937ceec4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,3 +1,4 @@ +import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { @@ -23,6 +24,8 @@ import { isLoRAModelFieldInputTemplate, isMainModelFieldInputInstance, isMainModelFieldInputTemplate, + isModelIdentifierFieldInputInstance, + isModelIdentifierFieldInputTemplate, isSchedulerFieldInputInstance, isSchedulerFieldInputTemplate, isSDXLMainModelFieldInputInstance, @@ -95,6 +98,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { return ; } + if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) { + return ; + } + if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx new file mode 100644 index 0000000000..6a0c9b63fa --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx @@ -0,0 +1,68 @@ +import { Combobox, Flex, FormControl } from '@invoke-ai/ui-library'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; +import { fieldModelIdentifierValueChanged } from 'features/nodes/store/nodesSlice'; +import type { ModelIdentifierFieldInputInstance, ModelIdentifierFieldInputTemplate } from 'features/nodes/types/field'; +import { memo, useCallback, useMemo } from 'react'; +import { modelConfigsAdapterSelectors, useGetModelConfigsQuery } from 'services/api/endpoints/models'; +import type { AnyModelConfig } from 'services/api/types'; + +import type { FieldComponentProps } from './types'; + +type Props = FieldComponentProps; + +const ModelIdentifierFieldInputComponent = (props: Props) => { + const { nodeId, field } = props; + const dispatch = useAppDispatch(); + const { data, isLoading } = useGetModelConfigsQuery(); + const _onChange = useCallback( + (value: AnyModelConfig | null) => { + if (!value) { + return; + } + dispatch( + fieldModelIdentifierValueChanged({ + nodeId, + fieldName: field.name, + value, + }) + ); + }, + [dispatch, field.name, nodeId] + ); + + const modelConfigs = useMemo(() => { + if (!data) { + return EMPTY_ARRAY; + } + + return modelConfigsAdapterSelectors.selectAll(data); + }, [data]); + + console.log(modelConfigs); + + const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ + modelConfigs, + onChange: _onChange, + isLoading, + selectedModel: field.value, + groupByType: true, + }); + + return ( + + + + + + ); +}; + +export default memo(ModelIdentifierFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 1f61c77e83..cec13e8df4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -16,6 +16,7 @@ import type { IPAdapterModelFieldValue, LoRAModelFieldValue, MainModelFieldValue, + ModelIdentifierFieldValue, SchedulerFieldValue, SDXLRefinerModelFieldValue, StatefulFieldValue, @@ -35,6 +36,7 @@ import { zIPAdapterModelFieldValue, zLoRAModelFieldValue, zMainModelFieldValue, + zModelIdentifierFieldValue, zSchedulerFieldValue, zSDXLRefinerModelFieldValue, zStatefulFieldValue, @@ -344,6 +346,9 @@ export const nodesSlice = createSlice({ fieldMainModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zMainModelFieldValue); }, + fieldModelIdentifierValueChanged: (state, action: FieldValueAction) => { + fieldValueReducer(state, action, zModelIdentifierFieldValue); + }, fieldRefinerModelValueChanged: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zSDXLRefinerModelFieldValue); }, @@ -469,6 +474,7 @@ export const { fieldT2IAdapterModelValueChanged, fieldLabelChanged, fieldLoRAModelValueChanged, + fieldModelIdentifierValueChanged, fieldMainModelValueChanged, fieldNumberValueChanged, fieldRefinerModelValueChanged, diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 4dcc478352..a98f773c7e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -106,6 +106,10 @@ const zMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('MainModelField'), originalType: zStatelessFieldType.optional(), }); +const zModelIdentifierFieldType = zFieldTypeBase.extend({ + name: z.literal('ModelIdentifierField'), + originalType: zStatelessFieldType.optional(), +}); const zSDXLMainModelFieldType = zFieldTypeBase.extend({ name: z.literal('SDXLMainModelField'), originalType: zStatelessFieldType.optional(), @@ -146,6 +150,7 @@ const zStatefulFieldType = z.union([ zEnumFieldType, zImageFieldType, zBoardFieldType, + zModelIdentifierFieldType, zMainModelFieldType, zSDXLMainModelFieldType, zSDXLRefinerModelFieldType, @@ -396,6 +401,29 @@ export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFie zMainModelFieldInputTemplate.safeParse(val).success; // #endregion +// #region ModelIdentifierField +export const zModelIdentifierFieldValue = zModelIdentifierField.optional(); +const zModelIdentifierFieldInputInstance = zFieldInputInstanceBase.extend({ + value: zModelIdentifierFieldValue, +}); +const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zModelIdentifierFieldType, + originalType: zFieldType.optional(), + default: zModelIdentifierFieldValue, +}); +const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zModelIdentifierFieldType, + originalType: zFieldType.optional(), +}); +export type ModelIdentifierFieldValue = z.infer; +export type ModelIdentifierFieldInputInstance = z.infer; +export type ModelIdentifierFieldInputTemplate = z.infer; +export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance => + zModelIdentifierFieldInputInstance.safeParse(val).success; +export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate => + zModelIdentifierFieldInputTemplate.safeParse(val).success; +// #endregion + // #region SDXLMainModelField const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. @@ -643,6 +671,7 @@ export const zStatefulFieldValue = z.union([ zEnumFieldValue, zImageFieldValue, zBoardFieldValue, + zModelIdentifierFieldValue, zMainModelFieldValue, zSDXLMainModelFieldValue, zSDXLRefinerModelFieldValue, @@ -669,6 +698,7 @@ const zStatefulFieldInputInstance = z.union([ zEnumFieldInputInstance, zImageFieldInputInstance, zBoardFieldInputInstance, + zModelIdentifierFieldInputInstance, zMainModelFieldInputInstance, zSDXLMainModelFieldInputInstance, zSDXLRefinerModelFieldInputInstance, @@ -696,6 +726,7 @@ const zStatefulFieldInputTemplate = z.union([ zEnumFieldInputTemplate, zImageFieldInputTemplate, zBoardFieldInputTemplate, + zModelIdentifierFieldInputTemplate, zMainModelFieldInputTemplate, zSDXLMainModelFieldInputTemplate, zSDXLRefinerModelFieldInputTemplate, @@ -724,6 +755,7 @@ const zStatefulFieldOutputTemplate = z.union([ zEnumFieldOutputTemplate, zImageFieldOutputTemplate, zBoardFieldOutputTemplate, + zModelIdentifierFieldOutputTemplate, zMainModelFieldOutputTemplate, zSDXLMainModelFieldOutputTemplate, zSDXLRefinerModelFieldOutputTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index f8097566c9..597779fd61 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -11,6 +11,7 @@ const FIELD_VALUE_FALLBACK_MAP: Record = IntegerField: 0, IPAdapterModelField: undefined, LoRAModelField: undefined, + ModelIdentifierField: undefined, MainModelField: undefined, SchedulerField: 'euler', SDXLMainModelField: undefined, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 6b4c4d8b29..2b77274526 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -13,6 +13,7 @@ import type { IPAdapterModelFieldInputTemplate, LoRAModelFieldInputTemplate, MainModelFieldInputTemplate, + ModelIdentifierFieldInputTemplate, SchedulerFieldInputTemplate, SDXLMainModelFieldInputTemplate, SDXLRefinerModelFieldInputTemplate, @@ -136,6 +137,20 @@ const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder = ({ + schemaObject, + baseField, + fieldType, +}) => { + const template: ModelIdentifierFieldInputTemplate = { + ...baseField, + type: fieldType, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder = ({ schemaObject, baseField, @@ -355,6 +370,7 @@ export const TEMPLATE_BUILDER_MAP: Record Date: Fri, 17 May 2024 20:47:46 +1000 Subject: [PATCH 012/207] feat(nodes): add `ModelIdentifierInvocation` This node allows a user to select _any_ model, outputting a `ModelIdentifierField` for that model. --- invokeai/app/invocations/model.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 05f451b957..6f78cf43bf 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -93,6 +93,32 @@ class ModelLoaderOutput(UNetOutput, CLIPOutput, VAEOutput): pass +@invocation_output("model_identifier_output") +class ModelIdentifierOutput(BaseInvocationOutput): + """Model identifier output""" + + model: ModelIdentifierField = OutputField(description="Model identifier", title="Model") + + +@invocation( + "model_identifier", + title="Model identifier", + tags=["model"], + category="model", + version="1.0.0", +) +class ModelIdentifierInvocation(BaseInvocation): + """Selects any model, outputting it.""" + + model: ModelIdentifierField = InputField(description="The model to select", title="Model") + + def invoke(self, context: InvocationContext) -> ModelIdentifierOutput: + if not context.models.exists(self.model.key): + raise Exception(f"Unknown model {self.model.key}") + + return ModelIdentifierOutput(model=self.model) + + @invocation( "main_model_loader", title="Main Model", From de1ea50e6dc24bff75a8c01741f6ad4ec41aea3c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 21:05:41 +1000 Subject: [PATCH 013/207] fix(ui): rebase resolution --- .../features/nodes/store/util/makeIsConnectionValidSelector.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index e7f659508f..5a5972a376 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -106,7 +106,7 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.cannotConnectToDirectInput'); } - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { + if (targetNode?.data.type === 'collect' && targetFieldName === 'item') { // Collect nodes shouldn't mix and match field types const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); if (collectItemType) { From af7b194bec0c8ab108c32c59c4e640f453378dae Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 May 2024 21:05:53 +1000 Subject: [PATCH 014/207] chore(ui): lint --- .../fields/inputs/ModelIdentifierFieldInputComponent.tsx | 2 -- .../web/src/features/nodes/util/schema/parseSchema.ts | 5 ----- 2 files changed, 7 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx index 6a0c9b63fa..4019689978 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent.tsx @@ -40,8 +40,6 @@ const ModelIdentifierFieldInputComponent = (props: Props) => { return modelConfigsAdapterSelectors.selectAll(data); }, [data]); - console.log(modelConfigs); - const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({ modelConfigs, onChange: _onChange, diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 0638a52954..f9b93382f9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -128,13 +128,10 @@ export const parseSchema = ( } if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { - console.log('STATEFUL WITH ORIGINAL'); fieldType.originalType = deepClone(originalFieldType); - console.log(fieldType); } const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType); - console.log(fieldInputTemplate); inputsAccumulator[propertyName] = fieldInputTemplate; return inputsAccumulator; @@ -201,9 +198,7 @@ export const parseSchema = ( } if (isStatefulFieldType(fieldType) && originalFieldType && !isEqual(originalFieldType, fieldType)) { - console.log('STATEFUL WITH ORIGINAL'); fieldType.originalType = deepClone(originalFieldType); - console.log(fieldType); } const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType); From 6658897210c12335cc7107466a03edecb89d028f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 17:22:29 +1000 Subject: [PATCH 015/207] tidy(ui): tidy connection validation functions and logic --- .../flow/AddNodePopover/AddNodePopover.tsx | 3 +- .../src/features/nodes/hooks/useConnection.ts | 2 +- .../nodes/hooks/useConnectionState.ts | 2 +- .../nodes/hooks/useIsValidConnection.ts | 11 +- .../nodes/store/util/connectionValidation.ts | 386 ++++++++++++++++++ .../store/util/findConnectionToValidHandle.ts | 105 ----- .../nodes/store/util/getIsGraphAcyclic.ts | 21 - .../util/makeIsConnectionValidSelector.ts | 146 ------- .../util/validateSourceAndTargetTypes.ts | 90 ---- 9 files changed, 396 insertions(+), 370 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 95104c683c..40fa13320a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -17,8 +17,7 @@ import { nodeAdded, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; -import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { filter, map, memoize, some } from 'lodash-es'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index df628ba5af..81eea993be 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -8,7 +8,7 @@ import { $templates, connectionMade, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; +import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { useCallback, useMemo } from 'react'; import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 728b492453..dfa8b0cf36 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -2,7 +2,7 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeIsConnectionValidSelector'; +import { makeConnectionErrorSelector } from 'features/nodes/store/util/connectionValidation.js'; import { useMemo } from 'react'; import { useFieldType } from './useFieldType.ts'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 14a7a728e0..b92114bab2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -2,9 +2,12 @@ import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; -import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; -import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import { areTypesEqual, validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { + areTypesEqual, + getCollectItemType, + getHasCycles, + validateSourceAndTargetTypes, +} from 'features/nodes/store/util/connectionValidation'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; @@ -90,7 +93,7 @@ export const useIsValidConnection = () => { } // Graphs much be acyclic (no loops!) - return getIsGraphAcyclic(source, target, nodes, edges); + return !getHasCycles(source, target, nodes, edges); }, [shouldValidateGraph, templates, store] ); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts new file mode 100644 index 0000000000..98de4284ad --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -0,0 +1,386 @@ +import graphlib from '@dagrejs/graphlib'; +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import type { PendingConnection, Templates } from 'features/nodes/store/types'; +import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import i18n from 'i18next'; +import { differenceWith, isEqual, map, omit } from 'lodash-es'; +import type { Connection, Edge, HandleType, Node } from 'reactflow'; +import { assert } from 'tsafe'; + +/** + * Finds the first valid field for a pending connection between two nodes. + * @param templates The invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param pendingConnection The pending connection + * @param candidateNode The candidate node to which the connection is being made + * @param candidateTemplate The candidate template for the candidate node + * @returns The first valid connection, or null if no valid connection is found + */ +export const getFirstValidConnection = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + pendingConnection: PendingConnection, + candidateNode: InvocationNode, + candidateTemplate: InvocationTemplate +): Connection | null => { + if (pendingConnection.node.id === candidateNode.id) { + // Cannot connect to self + return null; + } + + const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + + if (pendingFieldKind === 'source') { + // Connecting from a source to a target + if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { + return null; + } + if (candidateNode.data.type === 'collect') { + // Special handling for collect node - the `item` field takes any number of connections + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'item', + }; + } + // Only one connection per target field is allowed - look for an unconnected target field + const candidateFields = map(candidateTemplate.inputs); + const candidateConnectedFields = edges + .filter((edge) => edge.target === candidateNode.id) + .map((edge) => { + // Edges must always have a targetHandle, safe to assert here + assert(edge.targetHandle); + return edge.targetHandle; + }); + const candidateUnconnectedFields = differenceWith( + candidateFields, + candidateConnectedFields, + (field, connectedFieldName) => field.name === connectedFieldName + ); + const candidateField = candidateUnconnectedFields.find((field) => + validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type) + ); + if (candidateField) { + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: candidateField.name, + }; + } + } else { + // Connecting from a target to a source + // Ensure we there is not already an edge to the target, except for collect nodes + const isCollect = pendingConnection.node.data.type === 'collect'; + const isTargetAlreadyConnected = edges.some( + (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name + ); + if (!isCollect && isTargetAlreadyConnected) { + return null; + } + + if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { + return null; + } + + // Sources/outputs can have any number of edges, we can take the first matching output field + let candidateFields = map(candidateTemplate.outputs); + if (isCollect) { + // Narrow candidates to same field type as already is connected to the collect node + const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); + if (collectItemType) { + candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); + } + } + const candidateField = candidateFields.find((field) => { + const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type); + const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); + return isValid && !isAlreadyConnected; + }); + if (candidateField) { + return { + source: candidateNode.id, + sourceHandle: candidateField.name, + target: pendingConnection.node.id, + targetHandle: pendingConnection.fieldTemplate.name, + }; + } + } + + return null; +}; + +/** + * Check if adding an edge between the source and target nodes would create a cycle in the graph. + * @param source The source node id + * @param target The target node id + * @param nodes The graph's current nodes + * @param edges The graph's current edges + * @returns True if the graph would be acyclic after adding the edge, false otherwise + */ +export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { + // construct graphlib graph from editor state + const g = new graphlib.Graph(); + + nodes.forEach((n) => { + g.setNode(n.id); + }); + + edges.forEach((e) => { + g.setEdge(e.source, e.target); + }); + + // add the candidate edge + g.setEdge(source, target); + + // check if the graph is acyclic + return !graphlib.alg.isAcyclic(g); +}; + +/** + * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and + * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no + * input field. + * @param templates The current invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param nodeId The collect node's id + * @returns The type of the items the collect node collects, or null if there is no input field + */ +export const getCollectItemType = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + nodeId: string +): FieldType | null => { + const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); + if (!firstEdgeToCollect?.sourceHandle) { + return null; + } + const node = nodes.find((n) => n.id === firstEdgeToCollect.source); + if (!node) { + return null; + } + const template = templates[node.data.type]; + if (!template) { + return null; + } + const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; + return fieldType; +}; + +/** + * Creates a selector that validates a pending connection. + * + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` + * TODO: Figure out how to do this without duplicating all the logic + * + * @param templates The invocation templates + * @param pendingConnection The current pending connection (if there is one) + * @param nodeId The id of the node for which the selector is being created + * @param fieldName The name of the field for which the selector is being created + * @param handleType The type of the handle for which the selector is being created + * @param fieldType The type of the field for which the selector is being created + * @returns + */ +export const makeConnectionErrorSelector = ( + templates: Templates, + pendingConnection: PendingConnection | null, + nodeId: string, + fieldName: string, + handleType: HandleType, + fieldType: FieldType +) => { + return createMemoizedSelector(selectNodesSlice, (nodesSlice) => { + const { nodes, edges } = nodesSlice; + + if (!pendingConnection) { + return i18n.t('nodes.noConnectionInProgress'); + } + + const connectionNodeId = pendingConnection.node.id; + const connectionFieldName = pendingConnection.fieldTemplate.name; + const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + const connectionStartFieldType = pendingConnection.fieldTemplate.type; + + if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { + return i18n.t('nodes.noConnectionData'); + } + + const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; + const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; + + if (nodeId === connectionNodeId) { + return i18n.t('nodes.cannotConnectToSelf'); + } + + if (handleType === connectionHandleType) { + if (handleType === 'source') { + return i18n.t('nodes.cannotConnectOutputToOutput'); + } + return i18n.t('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; + const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; + const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; + const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; + + if ( + edges.find((edge) => { + edge.target === targetNodeId && + edge.targetHandle === targetFieldName && + edge.source === sourceNodeId && + edge.sourceHandle === sourceFieldName; + }) + ) { + // We already have a connection from this source to this target + return i18n.t('nodes.cannotDuplicateConnection'); + } + + const targetNode = nodes.find((node) => node.id === targetNodeId); + assert(targetNode, `Target node not found: ${targetNodeId}`); + const targetTemplate = templates[targetNode.data.type]; + assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); + + if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { + return i18n.t('nodes.cannotConnectToDirectInput'); + } + if (targetNode.data.type === 'collect' && targetFieldName === 'item') { + // Collect nodes shouldn't mix and match field types + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType) { + if (!areTypesEqual(sourceType, collectItemType)) { + return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + } + + if ( + edges.find((edge) => { + return edge.target === targetNodeId && edge.targetHandle === targetFieldName; + }) && + // except CollectionItem inputs can have multiples + targetType.name !== 'CollectionItemField' + ) { + return i18n.t('nodes.inputMayOnlyHaveOneConnection'); + } + + if (!validateSourceAndTargetTypes(sourceType, targetType)) { + return i18n.t('nodes.fieldTypesMustMatch'); + } + + const hasCycles = getHasCycles( + connectionHandleType === 'source' ? connectionNodeId : nodeId, + connectionHandleType === 'source' ? nodeId : connectionNodeId, + nodes, + edges + ); + + if (hasCycles) { + return i18n.t('nodes.connectionWouldCreateCycle'); + } + + return; + }); +}; + +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the connection is valid, false otherwise. + */ +export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => { + // TODO: There's a bug with Collect -> Iterate nodes: + // https://github.com/invoke-ai/InvokeAI/issues/3956 + // Once this is resolved, we can remove this check. + if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { + return false; + } + + if (areTypesEqual(sourceType, targetType)) { + return true; + } + + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type + * - Generic Collection can connect to any other Collection or CollectionOrScalar + * - Any Collection can connect to a Generic Collection + */ + const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; + + const isNonCollectionToCollectionItem = + targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; + + const isAnythingToCollectionOrScalarOfSameBaseType = + targetType.isCollectionOrScalar && sourceType.name === targetType.name; + + const isGenericCollectionToAnyCollectionOrCollectionOrScalar = + sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); + + const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; + + const areBothTypesSingle = + !sourceType.isCollection && + !sourceType.isCollectionOrScalar && + !targetType.isCollection && + !targetType.isCollectionOrScalar; + + const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; + + const isIntOrFloatToString = + areBothTypesSingle && + (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && + targetType.name === 'StringField'; + + const isTargetAnyType = targetType.name === 'AnyField'; + + // One of these must be true for the connection to be valid + return ( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToCollectionOrScalarOfSameBaseType || + isGenericCollectionToAnyCollectionOrCollectionOrScalar || + isCollectionToGenericCollection || + isIntToFloat || + isIntOrFloatToString || + isTargetAnyType + ); +}; + +/** + * Checks if two types are equal. If the field types have original types, those are also compared. Any match is + * considered equal. For example, if the source type and original target type match, the types are considered equal. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the types are equal, false otherwise. + */ +export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { + const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType; + const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType; + const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType; + const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType; + if (isEqual(_sourceType, _targetType)) { + return true; + } + if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { + return true; + } + if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { + return true; + } + if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + return true; + } + return false; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts deleted file mode 100644 index e0411ee67e..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ /dev/null @@ -1,105 +0,0 @@ -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { differenceWith, map } from 'lodash-es'; -import type { Connection } from 'reactflow'; -import { assert } from 'tsafe'; - -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; - -export const getFirstValidConnection = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate -): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self - return null; - } - - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (!getIsGraphAcyclic(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs).filter((i) => i.input !== 'direct'); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName - ); - const candidateField = candidateUnconnectedFields.find((field) => - validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { - return null; - } - - if (!getIsGraphAcyclic(candidateNode.id, pendingConnection.node.id, nodes, edges)) { - return null; - } - - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } - } - - return null; -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts b/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts deleted file mode 100644 index 2ef1c64c0e..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts +++ /dev/null @@ -1,21 +0,0 @@ -import graphlib from '@dagrejs/graphlib'; -import type { Edge, Node } from 'reactflow'; - -export const getIsGraphAcyclic = (source: string, target: string, nodes: Node[], edges: Edge[]) => { - // construct graphlib graph from editor state - const g = new graphlib.Graph(); - - nodes.forEach((n) => { - g.setNode(n.id); - }); - - edges.forEach((e) => { - g.setEdge(e.source, e.target); - }); - - // add the candidate edge - g.setEdge(source, target); - - // check if the graph is acyclic - return graphlib.alg.isAcyclic(g); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts deleted file mode 100644 index 5a5972a376..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ /dev/null @@ -1,146 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import type { FieldType } from 'features/nodes/types/field'; -import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; -import i18n from 'i18next'; -import type { HandleType } from 'reactflow'; -import { assert } from 'tsafe'; - -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { areTypesEqual, validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; - -export const getCollectItemType = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - nodeId: string -): FieldType | null => { - const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); - if (!firstEdgeToCollect?.sourceHandle) { - return null; - } - const node = nodes.find((n) => n.id === firstEdgeToCollect.source); - if (!node) { - return null; - } - const template = templates[node.data.type]; - if (!template) { - return null; - } - const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; - return fieldType; -}; - -/** - * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` - * TODO: Figure out how to do this without duplicating all the logic - */ - -export const makeConnectionErrorSelector = ( - templates: Templates, - pendingConnection: PendingConnection | null, - nodeId: string, - fieldName: string, - handleType: HandleType, - fieldType?: FieldType | null -) => { - return createSelector(selectNodesSlice, (nodesSlice) => { - const { nodes, edges } = nodesSlice; - - if (!fieldType) { - return i18n.t('nodes.noFieldType'); - } - - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); - } - return i18n.t('nodes.cannotConnectInputToInput'); - } - - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; - - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } - - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); - - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - - if (targetNode?.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); - } - } - } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateSourceAndTargetTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const isGraphAcyclic = getIsGraphAcyclic( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (!isGraphAcyclic) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - }); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts deleted file mode 100644 index 45b771b5b4..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ /dev/null @@ -1,90 +0,0 @@ -import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; -import { isEqual, omit } from 'lodash-es'; - -export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { - const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType; - const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType; - const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType; - const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType; - if (isEqual(_sourceType, _targetType)) { - return true; - } - if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { - return true; - } - if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { - return true; - } - if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { - return true; - } - return false; -}; - -/** - * Validates that the source and target types are compatible for a connection. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the connection is valid, false otherwise. - */ -export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => { - // TODO: There's a bug with Collect -> Iterate nodes: - // https://github.com/invoke-ai/InvokeAI/issues/3956 - // Once this is resolved, we can remove this check. - if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { - return false; - } - - if (areTypesEqual(sourceType, targetType)) { - return true; - } - - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar - * - Any Collection can connect to a Generic Collection - */ - - const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; - - const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; - - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; - - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); - - const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; - - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; - - const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; - - const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; - - const isTargetAnyType = targetType.name === 'AnyField'; - - // One of these must be true for the connection to be valid - return ( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || - isCollectionToGenericCollection || - isIntToFloat || - isIntOrFloatToString || - isTargetAnyType - ); -}; From 9d127fee6bc7ca2ac93ec91a65fbfb4fb27fbc39 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 17:28:56 +1000 Subject: [PATCH 016/207] feat(ui): makeConnectionErrorSelector now creates a parameterized selector --- .../nodes/hooks/useConnectionState.ts | 14 +- .../nodes/store/util/connectionValidation.ts | 176 +++++++++--------- 2 files changed, 93 insertions(+), 97 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index dfa8b0cf36..9571ce2ee2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -34,16 +34,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => - makeConnectionErrorSelector( - templates, - pendingConnection, - nodeId, - fieldName, - kind === 'inputs' ? 'target' : 'source', - fieldType - ), - [templates, pendingConnection, nodeId, fieldName, kind, fieldType] + () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), + [templates, nodeId, fieldName, kind, fieldType] ); const isConnected = useAppSelector(selectIsConnected); @@ -58,7 +50,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector(selectConnectionError); + const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection)); const shouldDim = useMemo( () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts index 98de4284ad..907426b51d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -1,7 +1,8 @@ import graphlib from '@dagrejs/graphlib'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import type { RootState } from 'app/store/store'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { PendingConnection, Templates } from 'features/nodes/store/types'; +import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import i18n from 'i18next'; @@ -190,105 +191,108 @@ export const getCollectItemType = ( */ export const makeConnectionErrorSelector = ( templates: Templates, - pendingConnection: PendingConnection | null, nodeId: string, fieldName: string, handleType: HandleType, fieldType: FieldType ) => { - return createMemoizedSelector(selectNodesSlice, (nodesSlice) => { - const { nodes, edges } = nodesSlice; + return createMemoizedSelector( + selectNodesSlice, + (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + const { nodes, edges } = nodesSlice; - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); + if (!pendingConnection) { + return i18n.t('nodes.noConnectionInProgress'); } - return i18n.t('nodes.cannotConnectInputToInput'); - } - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; + const connectionNodeId = pendingConnection.node.id; + const connectionFieldName = pendingConnection.fieldTemplate.name; + const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + const connectionStartFieldType = pendingConnection.fieldTemplate.type; - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } + if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { + return i18n.t('nodes.noConnectionData'); + } - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); + const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; + const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); + if (nodeId === connectionNodeId) { + return i18n.t('nodes.cannotConnectToSelf'); + } + + if (handleType === connectionHandleType) { + if (handleType === 'source') { + return i18n.t('nodes.cannotConnectOutputToOutput'); + } + return i18n.t('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; + const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; + const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; + const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; + + if ( + edges.find((edge) => { + edge.target === targetNodeId && + edge.targetHandle === targetFieldName && + edge.source === sourceNodeId && + edge.sourceHandle === sourceFieldName; + }) + ) { + // We already have a connection from this source to this target + return i18n.t('nodes.cannotDuplicateConnection'); + } + + const targetNode = nodes.find((node) => node.id === targetNodeId); + assert(targetNode, `Target node not found: ${targetNodeId}`); + const targetTemplate = templates[targetNode.data.type]; + assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); + + if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { + return i18n.t('nodes.cannotConnectToDirectInput'); + } + if (targetNode.data.type === 'collect' && targetFieldName === 'item') { + // Collect nodes shouldn't mix and match field types + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType) { + if (!areTypesEqual(sourceType, collectItemType)) { + return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); + } } } + + if ( + edges.find((edge) => { + return edge.target === targetNodeId && edge.targetHandle === targetFieldName; + }) && + // except CollectionItem inputs can have multiples + targetType.name !== 'CollectionItemField' + ) { + return i18n.t('nodes.inputMayOnlyHaveOneConnection'); + } + + if (!validateSourceAndTargetTypes(sourceType, targetType)) { + return i18n.t('nodes.fieldTypesMustMatch'); + } + + const hasCycles = getHasCycles( + connectionHandleType === 'source' ? connectionNodeId : nodeId, + connectionHandleType === 'source' ? nodeId : connectionNodeId, + nodes, + edges + ); + + if (hasCycles) { + return i18n.t('nodes.connectionWouldCreateCycle'); + } + + return; } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateSourceAndTargetTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const hasCycles = getHasCycles( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (hasCycles) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - }); + ); }; /** From 468644ab189edb82dfb3cb9cedcd4e94df705a2e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 17:37:28 +1000 Subject: [PATCH 017/207] fix(ui): rebase conflict --- .../web/src/features/nodes/store/util/connectionValidation.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts index 907426b51d..a2f723fcfe 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -255,6 +255,7 @@ export const makeConnectionErrorSelector = ( if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { return i18n.t('nodes.cannotConnectToDirectInput'); } + if (targetNode.data.type === 'collect' && targetFieldName === 'item') { // Collect nodes shouldn't mix and match field types const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); From 9f7841a04bb09403dd744c001dbb0b661df03388 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 18:46:03 +1000 Subject: [PATCH 018/207] tidy(ui): clean up addnodepopover hotkeys --- .../flow/AddNodePopover/AddNodePopover.tsx | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 40fa13320a..214fc069f9 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -21,7 +21,6 @@ import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/ import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { filter, map, memoize, some } from 'lodash-es'; -import type { KeyboardEventHandler } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { flushSync } from 'react-dom'; import { useHotkeys } from 'react-hotkeys-hook'; @@ -159,25 +158,24 @@ const AddNodePopover = () => { ); const handleHotkeyOpen: HotkeyCallback = useCallback((e) => { - e.preventDefault(); - openAddNodePopover(); - flushSync(() => { - selectRef.current?.inputRef?.focus(); - }); + if (!$isAddNodePopoverOpen.get()) { + e.preventDefault(); + openAddNodePopover(); + flushSync(() => { + selectRef.current?.inputRef?.focus(); + }); + } }, []); const handleHotkeyClose: HotkeyCallback = useCallback(() => { - closeAddNodePopover(); - }, []); - - useHotkeys(['shift+a', 'space'], handleHotkeyOpen); - useHotkeys(['escape'], handleHotkeyClose); - const onKeyDown: KeyboardEventHandler = useCallback((e) => { - if (e.key === 'Escape') { + if ($isAddNodePopoverOpen.get()) { closeAddNodePopover(); } }, []); + useHotkeys(['shift+a', 'space'], handleHotkeyOpen); + useHotkeys(['escape'], handleHotkeyClose, { enableOnFormTags: ['TEXTAREA'] }); + const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]); return ( @@ -214,7 +212,6 @@ const AddNodePopover = () => { filterOption={filterOption} onChange={onChange} onMenuClose={closeAddNodePopover} - onKeyDown={onKeyDown} inputRef={inputRef} closeMenuOnSelect={false} /> From 6b4e464d1780b53b6d7c05edefdeec2efa958085 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 18:55:37 +1000 Subject: [PATCH 019/207] fix(ui): rework edge update logic --- .../features/nodes/components/flow/Flow.tsx | 76 ++++++++++--------- .../src/features/nodes/store/nodesSlice.ts | 11 ++- 2 files changed, 46 insertions(+), 41 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 656de737c7..501513919a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -8,12 +8,13 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection' import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, + $didUpdateEdge, $isAddNodePopoverOpen, $isUpdatingEdge, + $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, connectionMade, - edgeAdded, edgeDeleted, edgesChanged, edgesDeleted, @@ -24,6 +25,7 @@ import { undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { isString } from 'lodash-es'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; @@ -39,7 +41,7 @@ import type { ReactFlowProps, ReactFlowState, } from 'reactflow'; -import { Background, ReactFlow, useStore as useReactFlowStore } from 'reactflow'; +import { Background, ReactFlow, useStore as useReactFlowStore, useUpdateNodeInternals } from 'reactflow'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; @@ -81,6 +83,7 @@ export const Flow = memo(() => { const flowWrapper = useRef(null); const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); + const updateNodeInternals = useUpdateNodeInternals(); useWorkflowWatcher(); useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); @@ -157,45 +160,46 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - // We have a ref for cursor position, but it is the *projected* cursor position. - // Easiest to just keep track of the last mouse event for this particular feature - const edgeUpdateMouseEvent = useRef(); - - const onEdgeUpdateStart: NonNullable = useCallback( - (e, edge, _handleType) => { - $isUpdatingEdge.set(true); - // update mouse event - edgeUpdateMouseEvent.current = e; - // always delete the edge when starting an updated - dispatch(edgeDeleted(edge.id)); - }, - [dispatch] - ); + const onEdgeUpdateStart: NonNullable = useCallback((e, _edge, _handleType) => { + $isUpdatingEdge.set(true); + $didUpdateEdge.set(false); + $lastEdgeUpdateMouseEvent.set(e); + }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (_oldEdge, newConnection) => { - // Because we deleted the edge when the update started, we must create a new edge from the connection + (edge, newConnection) => { + // This event is fired when an edge update is successful + $didUpdateEdge.set(true); + // When an edge update is successful, we need to delete the old edge and create a new one + dispatch(edgeDeleted(edge.id)); dispatch(connectionMade(newConnection)); + // Because we shift the position of handles depending on whether a field is connected or not, we must use + // updateNodeInternals to tell reactflow to recalculate the positions of the handles + const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString); + updateNodeInternals(nodesToUpdate); }, - [dispatch] + [dispatch, updateNodeInternals] ); const onEdgeUpdateEnd: NonNullable = useCallback( (e, edge, _handleType) => { - $isUpdatingEdge.set(false); - $pendingConnection.set(null); - // Handle the case where user begins a drag but didn't move the cursor - we deleted the edge when starting - // the edge update - we need to add it back - if ( - // ignore touch events - !('touches' in e) && - edgeUpdateMouseEvent.current?.clientX === e.clientX && - edgeUpdateMouseEvent.current?.clientY === e.clientY - ) { - dispatch(edgeAdded(edge)); + const didUpdateEdge = $didUpdateEdge.get(); + // Fall back to a reasonable default event + const lastEvent = $lastEdgeUpdateMouseEvent.get() ?? { clientX: 0, clientY: 0 }; + // We have to narrow this event down to MouseEvents - could be TouchEvent + const didMouseMove = + !('touches' in e) && Math.hypot(e.clientX - lastEvent.clientX, e.clientY - lastEvent.clientY) > 5; + + // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, + // the user probably intended to delete the edge + if (!didUpdateEdge && didMouseMove) { + dispatch(edgeDeleted(edge.id)); } - // reset mouse event - edgeUpdateMouseEvent.current = undefined; + + $isUpdatingEdge.set(false); + $didUpdateEdge.set(false); + $pendingConnection.set(null); + $lastEdgeUpdateMouseEvent.set(null); }, [dispatch] ); @@ -255,9 +259,11 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - $pendingConnection.set(null); - $isAddNodePopoverOpen.set(false); - cancelConnection(); + if (!$isUpdatingEdge.get()) { + $pendingConnection.set(null); + $isAddNodePopoverOpen.set(false); + cancelConnection(); + } }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index cec13e8df4..83632c16e1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -47,6 +47,7 @@ import { import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; +import type { MouseEvent } from 'react'; import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; @@ -125,9 +126,6 @@ export const nodesSlice = createSlice({ edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); }, - edgeAdded: (state, action: PayloadAction) => { - state.edges = addEdge(action.payload, state.edges); - }, connectionMade: (state, action: PayloadAction) => { state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); }, @@ -495,7 +493,6 @@ export const { notesNodeValueChanged, selectedAll, selectionPasted, - edgeAdded, undo, redo, } = nodesSlice.actions; @@ -507,6 +504,9 @@ export const $copiedEdges = atom([]); export const $edgesToCopiedNodes = atom([]); export const $pendingConnection = atom(null); export const $isUpdatingEdge = atom(false); +export const $didUpdateEdge = atom(false); +export const $lastEdgeUpdateMouseEvent = atom(null); + export const $viewport = atom({ x: 0, y: 0, zoom: 1 }); export const $isAddNodePopoverOpen = atom(false); export const closeAddNodePopover = () => { @@ -609,6 +609,5 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted, - edgeAdded + selectionPasted ); From 6f7160b9fd70d1cccefd2fb3ddd397e8580ed86d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 18 May 2024 18:59:14 +1000 Subject: [PATCH 020/207] fix(ui): call updateNodeInternals when making connections --- .../web/src/features/nodes/hooks/useConnection.ts | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index 81eea993be..f0dba67bf5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -10,13 +10,15 @@ import { } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; -import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; export const useConnection = () => { const store = useAppStore(); const templates = useStore($templates); + const updateNodeInternals = useUpdateNodeInternals(); const onConnectStart = useCallback( (event, params) => { @@ -41,9 +43,11 @@ export const useConnection = () => { (connection) => { const { dispatch } = store; dispatch(connectionMade(connection)); + const nodesToUpdate = [connection.source, connection.target].filter(isString); + updateNodeInternals(nodesToUpdate); $pendingConnection.set(null); }, - [store] + [store, updateNodeInternals] ); const onConnectEnd = useCallback(() => { const { dispatch } = store; @@ -80,13 +84,15 @@ export const useConnection = () => { ); if (connection) { dispatch(connectionMade(connection)); + const nodesToUpdate = [connection.source, connection.target].filter(isString); + updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); } else { // The mouse is not over a node - we should open the add node popover $isAddNodePopoverOpen.set(true); } - }, [store, templates]); + }, [store, templates, updateNodeInternals]); const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]); return api; From 3fcb2720d73c6ad13e972ad5f525448579f5796e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:11:15 +1000 Subject: [PATCH 021/207] tests(ui): add tests for consolidated connection validation --- invokeai/frontend/web/public/locales/en.json | 3 + .../flow/AddNodePopover/AddNodePopover.tsx | 5 +- .../src/features/nodes/hooks/useConnection.ts | 2 +- .../nodes/hooks/useIsValidConnection.ts | 12 +- .../nodes/store/util/areTypesEqual.test.ts | 101 ++ .../nodes/store/util/areTypesEqual.ts | 30 + .../nodes/store/util/connectionValidation.ts | 271 +---- .../store/util/getCollectItemType.test.ts | 16 + .../nodes/store/util/getCollectItemType.ts | 35 + .../store/util/getFirstValidConnection.ts | 116 ++ .../nodes/store/util/getHasCycles.test.ts | 23 + .../features/nodes/store/util/getHasCycles.ts | 30 + .../features/nodes/store/util/testUtils.ts | 1073 +++++++++++++++++ .../store/util/validateConnection.test.ts | 149 +++ .../nodes/store/util/validateConnection.ts | 109 ++ .../util/validateConnectionTypes.test.ts | 222 ++++ .../store/util/validateConnectionTypes.ts | 69 ++ .../web/src/features/nodes/types/field.ts | 19 - .../nodes/util/schema/parseSchema.test.ts | 937 +------------- 19 files changed, 1999 insertions(+), 1223 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7de7a8e01c..1f44e641fc 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -775,6 +775,9 @@ "cannotConnectToSelf": "Cannot connect to self", "cannotDuplicateConnection": "Cannot create duplicate connections", "cannotMixAndMatchCollectionItemTypes": "Cannot mix and match collection item types", + "missingNode": "Missing invocation node", + "missingInvocationTemplate": "Missing invocation template", + "missingFieldTemplate": "Missing field template", "nodePack": "Node pack", "collection": "Collection", "collectionFieldType": "{{name}} Collection", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 214fc069f9..14d69b4720 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -17,7 +17,8 @@ import { nodeAdded, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection, validateSourceAndTargetTypes } from 'features/nodes/store/util/connectionValidation'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { filter, map, memoize, some } from 'lodash-es'; @@ -77,7 +78,7 @@ const AddNodePopover = () => { return some(fields, (field) => { const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; - return validateSourceAndTargetTypes(sourceType, targetType); + return validateConnectionTypes(sourceType, targetType); }); }); }, [templates, pendingConnection]); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index f0dba67bf5..0190a0b29e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -8,7 +8,7 @@ import { $templates, connectionMade, } from 'features/nodes/store/nodesSlice'; -import { getFirstValidConnection } from 'features/nodes/store/util/connectionValidation'; +import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index b92114bab2..77c4e3c75b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -2,12 +2,10 @@ import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; -import { - areTypesEqual, - getCollectItemType, - getHasCycles, - validateSourceAndTargetTypes, -} from 'features/nodes/store/util/connectionValidation'; +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { useCallback } from 'react'; import type { Connection, Node } from 'reactflow'; @@ -88,7 +86,7 @@ export const useIsValidConnection = () => { } // Must use the originalType here if it exists - if (!validateSourceAndTargetTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts new file mode 100644 index 0000000000..7be307d07e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts @@ -0,0 +1,101 @@ +import { describe, expect, it } from 'vitest'; + +import { areTypesEqual } from './areTypesEqual'; + +describe(areTypesEqual.name, () => { + it('should handle equal source and target type', () => { + const sourceType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal source type and original target type', () => { + const sourceType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and target type', () => { + const sourceType = { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); + + it('should handle equal original source type and original target type', () => { + const sourceType = { + name: 'Foo', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + const targetType = { + name: 'Bar', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + }; + expect(areTypesEqual(sourceType, targetType)).toBe(true); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts new file mode 100644 index 0000000000..e01b48b972 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts @@ -0,0 +1,30 @@ +import type { FieldType } from 'features/nodes/types/field'; +import { isEqual, omit } from 'lodash-es'; + +/** + * Checks if two types are equal. If the field types have original types, those are also compared. Any match is + * considered equal. For example, if the source type and original target type match, the types are considered equal. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the types are equal, false otherwise. + */ + +export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { + const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType; + const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType; + const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null; + const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null; + if (isEqual(_sourceType, _targetType)) { + return true; + } + if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { + return true; + } + if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { + return true; + } + if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + return true; + } + return false; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts index a2f723fcfe..7819221f8a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts @@ -1,179 +1,16 @@ -import graphlib from '@dagrejs/graphlib'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import type { RootState } from 'app/store/store'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { type FieldType, isStatefulFieldType } from 'features/nodes/types/field'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { FieldType } from 'features/nodes/types/field'; import i18n from 'i18next'; -import { differenceWith, isEqual, map, omit } from 'lodash-es'; -import type { Connection, Edge, HandleType, Node } from 'reactflow'; +import type { HandleType } from 'reactflow'; import { assert } from 'tsafe'; -/** - * Finds the first valid field for a pending connection between two nodes. - * @param templates The invocation templates - * @param nodes The current nodes - * @param edges The current edges - * @param pendingConnection The pending connection - * @param candidateNode The candidate node to which the connection is being made - * @param candidateTemplate The candidate template for the candidate node - * @returns The first valid connection, or null if no valid connection is found - */ -export const getFirstValidConnection = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate -): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self - return null; - } - - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName - ); - const candidateField = candidateUnconnectedFields.find((field) => - validateSourceAndTargetTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { - return null; - } - - if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { - return null; - } - - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateSourceAndTargetTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } - } - - return null; -}; - -/** - * Check if adding an edge between the source and target nodes would create a cycle in the graph. - * @param source The source node id - * @param target The target node id - * @param nodes The graph's current nodes - * @param edges The graph's current edges - * @returns True if the graph would be acyclic after adding the edge, false otherwise - */ -export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { - // construct graphlib graph from editor state - const g = new graphlib.Graph(); - - nodes.forEach((n) => { - g.setNode(n.id); - }); - - edges.forEach((e) => { - g.setEdge(e.source, e.target); - }); - - // add the candidate edge - g.setEdge(source, target); - - // check if the graph is acyclic - return !graphlib.alg.isAcyclic(g); -}; - -/** - * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and - * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no - * input field. - * @param templates The current invocation templates - * @param nodes The current nodes - * @param edges The current edges - * @param nodeId The collect node's id - * @returns The type of the items the collect node collects, or null if there is no input field - */ -export const getCollectItemType = ( - templates: Templates, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - nodeId: string -): FieldType | null => { - const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); - if (!firstEdgeToCollect?.sourceHandle) { - return null; - } - const node = nodes.find((n) => n.id === firstEdgeToCollect.source); - if (!node) { - return null; - } - const template = templates[node.data.type]; - if (!template) { - return null; - } - const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; - return fieldType; -}; +import { areTypesEqual } from './areTypesEqual'; +import { getCollectItemType } from './getCollectItemType'; +import { getHasCycles } from './getHasCycles'; /** * Creates a selector that validates a pending connection. @@ -276,7 +113,7 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } - if (!validateSourceAndTargetTypes(sourceType, targetType)) { + if (!validateConnectionTypes(sourceType, targetType)) { return i18n.t('nodes.fieldTypesMustMatch'); } @@ -295,97 +132,3 @@ export const makeConnectionErrorSelector = ( } ); }; - -/** - * Validates that the source and target types are compatible for a connection. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the connection is valid, false otherwise. - */ -export const validateSourceAndTargetTypes = (sourceType: FieldType, targetType: FieldType) => { - // TODO: There's a bug with Collect -> Iterate nodes: - // https://github.com/invoke-ai/InvokeAI/issues/3956 - // Once this is resolved, we can remove this check. - if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { - return false; - } - - if (areTypesEqual(sourceType, targetType)) { - return true; - } - - /** - * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar - * - Any Collection can connect to a Generic Collection - */ - const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; - - const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; - - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; - - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); - - const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; - - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; - - const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; - - const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; - - const isTargetAnyType = targetType.name === 'AnyField'; - - // One of these must be true for the connection to be valid - return ( - isCollectionItemToNonCollection || - isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || - isCollectionToGenericCollection || - isIntToFloat || - isIntOrFloatToString || - isTargetAnyType - ); -}; - -/** - * Checks if two types are equal. If the field types have original types, those are also compared. Any match is - * considered equal. For example, if the source type and original target type match, the types are considered equal. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. - * @returns True if the types are equal, false otherwise. - */ -export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { - const _sourceType = isStatefulFieldType(sourceType) ? omit(sourceType, 'originalType') : sourceType; - const _targetType = isStatefulFieldType(targetType) ? omit(targetType, 'originalType') : targetType; - const _sourceTypeOriginal = isStatefulFieldType(sourceType) ? sourceType.originalType : sourceType; - const _targetTypeOriginal = isStatefulFieldType(targetType) ? targetType.originalType : targetType; - if (isEqual(_sourceType, _targetType)) { - return true; - } - if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { - return true; - } - if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { - return true; - } - if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { - return true; - } - return false; -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts new file mode 100644 index 0000000000..93c63b6f41 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -0,0 +1,16 @@ +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils'; +import type { FieldType } from 'features/nodes/types/field'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { describe, expect, it } from 'vitest'; + +describe(getCollectItemType.name, () => { + it('should return the type of the items the collect node collects', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const nodes = [n1, n2]; + const edges = [buildEdge(n1.id, 'value', n2.id, 'item')]; + const result = getCollectItemType(templates, nodes, edges, n2.id); + expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts new file mode 100644 index 0000000000..9e0ce0fbee --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts @@ -0,0 +1,35 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; + +/** + * Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and + * field connected to the collector's `item` input. The field type of that field is returned, else null if there is no + * input field. + * @param templates The current invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param nodeId The collect node's id + * @returns The type of the items the collect node collects, or null if there is no input field + */ +export const getCollectItemType = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + nodeId: string +): FieldType | null => { + const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item'); + if (!firstEdgeToCollect?.sourceHandle) { + return null; + } + const node = nodes.find((n) => n.id === firstEdgeToCollect.source); + if (!node) { + return null; + } + const template = templates[node.data.type]; + if (!template) { + return null; + } + const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; + return fieldType; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts new file mode 100644 index 0000000000..98155f0c20 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -0,0 +1,116 @@ +import type { PendingConnection, Templates } from 'features/nodes/store/types'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; +import { differenceWith, map } from 'lodash-es'; +import type { Connection } from 'reactflow'; +import { assert } from 'tsafe'; + +import { areTypesEqual } from './areTypesEqual'; +import { getCollectItemType } from './getCollectItemType'; +import { getHasCycles } from './getHasCycles'; + +/** + * Finds the first valid field for a pending connection between two nodes. + * @param templates The invocation templates + * @param nodes The current nodes + * @param edges The current edges + * @param pendingConnection The pending connection + * @param candidateNode The candidate node to which the connection is being made + * @param candidateTemplate The candidate template for the candidate node + * @returns The first valid connection, or null if no valid connection is found + */ + +export const getFirstValidConnection = ( + templates: Templates, + nodes: AnyNode[], + edges: InvocationNodeEdge[], + pendingConnection: PendingConnection, + candidateNode: InvocationNode, + candidateTemplate: InvocationTemplate +): Connection | null => { + if (pendingConnection.node.id === candidateNode.id) { + // Cannot connect to self + return null; + } + + const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + + if (pendingFieldKind === 'source') { + // Connecting from a source to a target + if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { + return null; + } + if (candidateNode.data.type === 'collect') { + // Special handling for collect node - the `item` field takes any number of connections + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'item', + }; + } + // Only one connection per target field is allowed - look for an unconnected target field + const candidateFields = map(candidateTemplate.inputs); + const candidateConnectedFields = edges + .filter((edge) => edge.target === candidateNode.id) + .map((edge) => { + // Edges must always have a targetHandle, safe to assert here + assert(edge.targetHandle); + return edge.targetHandle; + }); + const candidateUnconnectedFields = differenceWith( + candidateFields, + candidateConnectedFields, + (field, connectedFieldName) => field.name === connectedFieldName + ); + const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) + ); + if (candidateField) { + return { + source: pendingConnection.node.id, + sourceHandle: pendingConnection.fieldTemplate.name, + target: candidateNode.id, + targetHandle: candidateField.name, + }; + } + } else { + // Connecting from a target to a source + // Ensure we there is not already an edge to the target, except for collect nodes + const isCollect = pendingConnection.node.data.type === 'collect'; + const isTargetAlreadyConnected = edges.some( + (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name + ); + if (!isCollect && isTargetAlreadyConnected) { + return null; + } + + if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { + return null; + } + + // Sources/outputs can have any number of edges, we can take the first matching output field + let candidateFields = map(candidateTemplate.outputs); + if (isCollect) { + // Narrow candidates to same field type as already is connected to the collect node + const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); + if (collectItemType) { + candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); + } + } + const candidateField = candidateFields.find((field) => { + const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type); + const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); + return isValid && !isAlreadyConnected; + }); + if (candidateField) { + return { + source: candidateNode.id, + sourceHandle: candidateField.name, + target: pendingConnection.node.id, + targetHandle: pendingConnection.fieldTemplate.name, + }; + } + } + + return null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts new file mode 100644 index 0000000000..872da36998 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts @@ -0,0 +1,23 @@ +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { add, buildEdge, position } from 'features/nodes/store/util/testUtils'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { describe, expect, it } from 'vitest'; + +describe(getHasCycles.name, () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + + it('should return true if the graph WOULD have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a'), buildEdge(n2.id, 'value', n3.id, 'a')]; + const result = getHasCycles(n3.id, n1.id, nodes, edges); + expect(result).toBe(true); + }); + + it('should return false if the graph WOULD NOT have cycles after adding the edge', () => { + const edges = [buildEdge(n1.id, 'value', n2.id, 'a')]; + const result = getHasCycles(n2.id, n3.id, nodes, edges); + expect(result).toBe(false); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts new file mode 100644 index 0000000000..c1a4e51f0c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.ts @@ -0,0 +1,30 @@ +import graphlib from '@dagrejs/graphlib'; +import type { Edge, Node } from 'reactflow'; + +/** + * Check if adding an edge between the source and target nodes would create a cycle in the graph. + * @param source The source node id + * @param target The target node id + * @param nodes The graph's current nodes + * @param edges The graph's current edges + * @returns True if the graph would be acyclic after adding the edge, false otherwise + */ + +export const getHasCycles = (source: string, target: string, nodes: Node[], edges: Edge[]) => { + // construct graphlib graph from editor state + const g = new graphlib.Graph(); + + nodes.forEach((n) => { + g.setNode(n.id); + }); + + edges.forEach((e) => { + g.setEdge(e.source, e.target); + }); + + // add the candidate edge + g.setEdge(source, target); + + // check if the graph is acyclic + return !graphlib.alg.isAcyclic(g); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts new file mode 100644 index 0000000000..efde3336e2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -0,0 +1,1073 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; +import type { OpenAPIV3_1 } from 'openapi-types'; +import type { Edge, XYPosition } from 'reactflow'; + +export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ + source, + sourceHandle, + target, + targetHandle, + type: 'default', + id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, +}); + +export const position: XYPosition = { x: 0, y: 0 }; + +export const add: InvocationTemplate = { + title: 'Add Integers', + type: 'add', + version: '1.0.1', + tags: ['math', 'add'], + description: 'Adds two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const sub: InvocationTemplate = { + title: 'Subtract Integers', + type: 'sub', + version: '1.0.1', + tags: ['math', 'subtract'], + description: 'Subtracts two numbers', + outputType: 'integer_output', + inputs: { + a: { + name: 'a', + title: 'A', + required: false, + description: 'The first number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + b: { + name: 'b', + title: 'B', + required: false, + description: 'The second number', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 0, + }, + }, + outputs: { + value: { + fieldKind: 'output', + name: 'value', + title: 'Value', + description: 'The output integer', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const collect: InvocationTemplate = { + title: 'Collect', + type: 'collect', + version: '1.0.0', + tags: [], + description: 'Collects values into a collection', + outputType: 'collect_output', + inputs: { + item: { + name: 'item', + title: 'Collection Item', + required: false, + description: 'The item to collect (all inputs must be of the same type)', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionItemField', + type: { + name: 'CollectionItemField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + }, + outputs: { + collection: { + fieldKind: 'output', + name: 'collection', + title: 'Collection', + description: 'The collection of input items', + type: { + name: 'CollectionField', + isCollection: true, + isCollectionOrScalar: false, + }, + ui_hidden: false, + ui_type: 'CollectionField', + }, + }, + useCache: true, + classification: 'stable', +}; + +export const scheduler: InvocationTemplate = { + title: 'Scheduler', + type: 'scheduler', + version: '1.0.0', + tags: ['scheduler'], + description: 'Selects a scheduler.', + outputType: 'scheduler_output', + inputs: { + scheduler: { + name: 'scheduler', + title: 'Scheduler', + required: false, + description: 'Scheduler to use during inference', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + ui_type: 'SchedulerField', + type: { + name: 'SchedulerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + default: 'euler', + }, + }, + outputs: { + scheduler: { + fieldKind: 'output', + name: 'scheduler', + title: 'Scheduler', + description: 'Scheduler to use during inference', + type: { + name: 'SchedulerField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; + +export const main_model_loader: InvocationTemplate = { + title: 'Main Model', + type: 'main_model_loader', + version: '1.0.2', + tags: ['model'], + description: 'Loads a main model, outputting its submodels.', + outputType: 'model_loader_output', + inputs: { + model: { + name: 'model', + title: 'Model', + required: true, + description: 'Main model (UNet, VAE, CLIP) to load', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + ui_type: 'MainModelField', + type: { + name: 'MainModelField', + isCollection: false, + isCollectionOrScalar: false, + originalType: { + name: 'ModelIdentifierField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + }, + }, + outputs: { + vae: { + fieldKind: 'output', + name: 'vae', + title: 'VAE', + description: 'VAE', + type: { + name: 'VAEField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + clip: { + fieldKind: 'output', + name: 'clip', + title: 'CLIP', + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + type: { + name: 'CLIPField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + unet: { + fieldKind: 'output', + name: 'unet', + title: 'UNet', + description: 'UNet (scheduler, LoRAs)', + type: { + name: 'UNetField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +} + +export const templates: Templates = { + add, + sub, + collect, + scheduler, + main_model_loader, +}; + +export const schema = { + openapi: '3.1.0', + info: { + title: 'Invoke - Community Edition', + description: 'An API for invoking AI image operations', + version: '1.0.0', + }, + components: { + schemas: { + AddInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['add'], + const: 'add', + title: 'type', + default: 'add', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Add Integers', + description: 'Adds two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'add'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + IntegerOutput: { + description: 'Base class for nodes that output a single integer', + properties: { + value: { + description: 'The output integer', + field_kind: 'output', + title: 'Value', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'integer_output', + default: 'integer_output', + enum: ['integer_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['value', 'type', 'type'], + title: 'IntegerOutput', + type: 'object', + class: 'output', + }, + SchedulerInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + scheduler: { + type: 'string', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + title: 'Scheduler', + description: 'Scheduler to use during inference', + default: 'euler', + field_kind: 'input', + input: 'any', + orig_default: 'euler', + orig_required: false, + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + type: 'string', + enum: ['scheduler'], + const: 'scheduler', + title: 'type', + default: 'scheduler', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Scheduler', + description: 'Selects a scheduler.', + category: 'latents', + classification: 'stable', + node_pack: 'invokeai', + tags: ['scheduler'], + version: '1.0.0', + output: { + $ref: '#/components/schemas/SchedulerOutput', + }, + class: 'invocation', + }, + SchedulerOutput: { + properties: { + scheduler: { + description: 'Scheduler to use during inference', + enum: [ + 'ddim', + 'ddpm', + 'deis', + 'lms', + 'lms_k', + 'pndm', + 'heun', + 'heun_k', + 'euler', + 'euler_k', + 'euler_a', + 'kdpm_2', + 'kdpm_2_a', + 'dpmpp_2s', + 'dpmpp_2s_k', + 'dpmpp_2m', + 'dpmpp_2m_k', + 'dpmpp_2m_sde', + 'dpmpp_2m_sde_k', + 'dpmpp_sde', + 'dpmpp_sde_k', + 'unipc', + 'lcm', + 'tcd', + ], + field_kind: 'output', + title: 'Scheduler', + type: 'string', + ui_hidden: false, + ui_type: 'SchedulerField', + }, + type: { + const: 'scheduler_output', + default: 'scheduler_output', + enum: ['scheduler_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['scheduler', 'type', 'type'], + title: 'SchedulerOutput', + type: 'object', + class: 'output', + }, + MainModelLoaderInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + model: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Main model (UNet, VAE, CLIP) to load', + field_kind: 'input', + input: 'direct', + orig_required: true, + ui_hidden: false, + ui_type: 'MainModelField', + }, + type: { + type: 'string', + enum: ['main_model_loader'], + const: 'main_model_loader', + title: 'type', + default: 'main_model_loader', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['model', 'type', 'id'], + title: 'Main Model', + description: 'Loads a main model, outputting its submodels.', + category: 'model', + classification: 'stable', + node_pack: 'invokeai', + tags: ['model'], + version: '1.0.2', + output: { + $ref: '#/components/schemas/ModelLoaderOutput', + }, + class: 'invocation', + }, + ModelIdentifierField: { + properties: { + key: { + description: "The model's unique key", + title: 'Key', + type: 'string', + }, + hash: { + description: "The model's BLAKE3 hash", + title: 'Hash', + type: 'string', + }, + name: { + description: "The model's name", + title: 'Name', + type: 'string', + }, + base: { + allOf: [ + { + $ref: '#/components/schemas/BaseModelType', + }, + ], + description: "The model's base model type", + }, + type: { + allOf: [ + { + $ref: '#/components/schemas/ModelType', + }, + ], + description: "The model's type", + }, + submodel_type: { + anyOf: [ + { + $ref: '#/components/schemas/SubModelType', + }, + { + type: 'null', + }, + ], + default: null, + description: 'The submodel to load, if this is a main model', + }, + }, + required: ['key', 'hash', 'name', 'base', 'type'], + title: 'ModelIdentifierField', + type: 'object', + }, + BaseModelType: { + description: 'Base model type.', + enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], + title: 'BaseModelType', + type: 'string', + }, + ModelType: { + description: 'Model type.', + enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], + title: 'ModelType', + type: 'string', + }, + SubModelType: { + description: 'Submodel type.', + enum: [ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', + ], + title: 'SubModelType', + type: 'string', + }, + ModelLoaderOutput: { + description: 'Model loader output', + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/VAEField', + }, + ], + description: 'VAE', + field_kind: 'output', + title: 'VAE', + ui_hidden: false, + }, + type: { + const: 'model_loader_output', + default: 'model_loader_output', + enum: ['model_loader_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + clip: { + allOf: [ + { + $ref: '#/components/schemas/CLIPField', + }, + ], + description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', + field_kind: 'output', + title: 'CLIP', + ui_hidden: false, + }, + unet: { + allOf: [ + { + $ref: '#/components/schemas/UNetField', + }, + ], + description: 'UNet (scheduler, LoRAs)', + field_kind: 'output', + title: 'UNet', + ui_hidden: false, + }, + }, + required: ['vae', 'type', 'clip', 'unet', 'type'], + title: 'ModelLoaderOutput', + type: 'object', + class: 'output', + }, + UNetField: { + properties: { + unet: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load unet submodel', + }, + scheduler: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load scheduler submodel', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + freeu_config: { + anyOf: [ + { + $ref: '#/components/schemas/FreeUConfig', + }, + { + type: 'null', + }, + ], + default: null, + description: 'FreeU configuration', + }, + }, + required: ['unet', 'scheduler', 'loras'], + title: 'UNetField', + type: 'object', + class: 'output', + }, + LoRAField: { + properties: { + lora: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load lora model', + }, + weight: { + description: 'Weight to apply to lora model', + title: 'Weight', + type: 'number', + }, + }, + required: ['lora', 'weight'], + title: 'LoRAField', + type: 'object', + class: 'output', + }, + FreeUConfig: { + description: + 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', + properties: { + s1: { + description: + 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S1', + type: 'number', + }, + s2: { + description: + 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', + maximum: 3, + minimum: -1, + title: 'S2', + type: 'number', + }, + b1: { + description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B1', + type: 'number', + }, + b2: { + description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', + maximum: 3, + minimum: -1, + title: 'B2', + type: 'number', + }, + }, + required: ['s1', 's2', 'b1', 'b2'], + title: 'FreeUConfig', + type: 'object', + class: 'output', + }, + VAEField: { + properties: { + vae: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load vae submodel', + }, + seamless_axes: { + description: 'Axes("x" and "y") to which apply seamless', + items: { + type: 'string', + }, + title: 'Seamless Axes', + type: 'array', + }, + }, + required: ['vae'], + title: 'VAEField', + type: 'object', + class: 'output', + }, + CLIPField: { + properties: { + tokenizer: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load tokenizer submodel', + }, + text_encoder: { + allOf: [ + { + $ref: '#/components/schemas/ModelIdentifierField', + }, + ], + description: 'Info to load text_encoder submodel', + }, + skipped_layers: { + description: 'Number of skipped layers in text_encoder', + title: 'Skipped Layers', + type: 'integer', + }, + loras: { + description: 'LoRAs to apply on model loading', + items: { + $ref: '#/components/schemas/LoRAField', + }, + title: 'Loras', + type: 'array', + }, + }, + required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], + title: 'CLIPField', + type: 'object', + class: 'output', + }, + CollectInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + item: { + anyOf: [ + {}, + { + type: 'null', + }, + ], + title: 'Collection Item', + description: 'The item to collect (all inputs must be of the same type)', + field_kind: 'input', + input: 'connection', + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The collection, will be provided on execution', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['collect'], + const: 'collect', + title: 'type', + default: 'collect', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'CollectInvocation', + description: 'Collects values into a collection', + classification: 'stable', + version: '1.0.0', + output: { + $ref: '#/components/schemas/CollectInvocationOutput', + }, + class: 'invocation', + }, + CollectInvocationOutput: { + properties: { + collection: { + description: 'The collection of input items', + field_kind: 'output', + items: {}, + title: 'Collection', + type: 'array', + ui_hidden: false, + ui_type: 'CollectionField', + }, + type: { + const: 'collect_output', + default: 'collect_output', + enum: ['collect_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['collection', 'type', 'type'], + title: 'CollectInvocationOutput', + type: 'object', + class: 'output', + }, + SubtractInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + a: { + type: 'integer', + title: 'A', + description: 'The first number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + b: { + type: 'integer', + title: 'B', + description: 'The second number', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['sub'], + const: 'sub', + title: 'type', + default: 'sub', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Subtract Integers', + description: 'Subtracts two numbers', + category: 'math', + classification: 'stable', + node_pack: 'invokeai', + tags: ['math', 'subtract'], + version: '1.0.1', + output: { + $ref: '#/components/schemas/IntegerOutput', + }, + class: 'invocation', + }, + }, + }, +} as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts new file mode 100644 index 0000000000..5d10ef368b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -0,0 +1,149 @@ +import { deepClone } from 'common/util/deepClone'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { set } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils'; +import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; + +describe(validateConnection.name, () => { + it('should reject invalid connection to self', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + + describe('missing nodes', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + + it('should reject missing source node', () => { + const r = validateConnection(c, [n2], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + + it('should reject missing target node', () => { + const r = validateConnection(c, [n1], [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingNode')); + }); + }); + + describe('missing invocation templates', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const nodes = [n1, n2]; + + it('should reject missing source template', () => { + const r = validateConnection(c, nodes, [], { sub }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + + it('should reject missing target template', () => { + const r = validateConnection(c, nodes, [], { add }, null); + expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + }); + }); + + describe('missing field templates', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + + it('should reject missing source field template', () => { + const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + + it('should reject missing target field template', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + }); + }); + + describe('duplicate connections', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + it('should accept non-duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + it('should reject duplicate connections', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection')); + }); + it('should accept duplicate connections if the duplicate is an ignored edge', () => { + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const e = buildEdge(n1.id, 'value', n2.id, 'a'); + const r = validateConnection(c, [n1, n2], [e], templates, e); + expect(r).toEqual(buildAcceptResult()); + }); + }); + + it('should reject connection to direct input', () => { + // Create cloned add template w/ a direct input + const addWithDirectAField = deepClone(add); + set(addWithDirectAField, 'inputs.a.input', 'direct'); + set(addWithDirectAField, 'type', 'addWithDirectAField'); + + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, addWithDirectAField); + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput')); + }); + + it('should reject connection to a collect node with mismatched item types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const n3 = buildInvocationNode(position, main_model_loader); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes')); + }); + + it('should accept connection to a collect node with matching item types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, collect); + const n3 = buildInvocationNode(position, sub); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildAcceptResult()); + }); + + it('should reject connections to target field that is already connected', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection')); + }); + + it('should accept connections to target field that is already connected (ignored edge)', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, add); + const n3 = buildInvocationNode(position, add); + const nodes = [n1, n2, n3]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, e1); + expect(r).toEqual(buildAcceptResult()); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts new file mode 100644 index 0000000000..d45a75ab9f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -0,0 +1,109 @@ +import type { Templates } from 'features/nodes/store/types'; +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import type { AnyNode } from 'features/nodes/types/invocation'; +import type { Connection as NullableConnection, Edge } from 'reactflow'; +import type { O } from 'ts-toolbelt'; + +type Connection = O.NonNullable; + +export type ValidateConnectionResult = { + isValid: boolean; + messageTKey?: string; +}; + +export type ValidateConnectionFunc = ( + connection: Connection, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + ignoreEdge: Edge | null +) => ValidateConnectionResult; + +export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({ + isValid, + messageTKey, +}); + +const getEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return ( + e.target === c.target && + e.targetHandle === c.targetHandle && + e.source === c.source && + e.sourceHandle === c.sourceHandle + ); + }; + +export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); +export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); + +export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => { + if (c.source === c.target) { + return buildRejectResult('nodes.cannotConnectToSelf'); + } + + const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); + + if (filteredEdges.some(getEqualityPredicate(c))) { + // We already have a connection from this source to this target + return buildRejectResult('nodes.cannotDuplicateConnection'); + } + + const sourceNode = nodes.find((n) => n.id === c.source); + if (!sourceNode) { + return buildRejectResult('nodes.missingNode'); + } + + const targetNode = nodes.find((n) => n.id === c.target); + if (!targetNode) { + return buildRejectResult('nodes.missingNode'); + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; + if (!sourceFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + if (!targetFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + if (targetFieldTemplate.input === 'direct') { + return buildRejectResult('nodes.cannotConnectToDirectInput'); + } + + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { + // Collect nodes shouldn't mix and match field types + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType) { + if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + } + + if ( + edges.find((e) => { + return e.target === c.target && e.targetHandle === c.targetHandle; + }) && + // except CollectionItem inputs can have multiples + targetFieldTemplate.type.name !== 'CollectionItemField' + ) { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } + + return buildAcceptResult(); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts new file mode 100644 index 0000000000..d953fd973f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -0,0 +1,222 @@ +import { describe, expect, it } from 'vitest'; + +import { validateConnectionTypes } from './validateConnectionTypes'; + +describe(validateConnectionTypes.name, () => { + describe('generic cases', () => { + it('should accept Scalar to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept Collection to Collection of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept Scalar to CollectionOrScalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept Collection to CollectionOrScalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should reject Collection to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject CollectionOrScalar to Scalar of same type', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: true }, + { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject mismatched types', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'BarField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + }); + + describe('special cases', () => { + it('should reject a collection input to a collection input', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, + { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + + it('should accept equal types', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + + describe('CollectionItemField', () => { + it('should accept CollectionItemField to any Scalar target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept CollectionItemField to any CollectionOrScalar target', () => { + const r = validateConnectionTypes( + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any non-Collection to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should reject any Collection to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + it('should reject any CollectionOrScalar to CollectionItemField', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(false); + }); + }); + + describe('CollectionOrScalar', () => { + it('should accept any Scalar of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any Collection of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => { + const r = validateConnectionTypes( + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('CollectionField', () => { + it('should accept any CollectionField to any Collection type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionField to any CollectionOrScalar type', () => { + const r = validateConnectionTypes( + { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, + { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('subtype handling', () => { + type TypePair = { t1: string; t2: string }; + const typePairs = [ + { t1: 'IntegerField', t2: 'FloatField' }, + { t1: 'IntegerField', t2: 'StringField' }, + { t1: 'FloatField', t2: 'StringField' }, + ]; + it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: true, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: true, isCollectionOrScalar: false }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, isCollection: false, isCollectionOrScalar: true }, + { name: t2, isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + + describe('AnyField', () => { + it('should accept any Scalar type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: false, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any Collection type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: true, isCollectionOrScalar: false } + ); + expect(r).toBe(true); + }); + it('should accept any CollectionOrScalar type to AnyField', () => { + const r = validateConnectionTypes( + { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, + { name: 'AnyField', isCollection: false, isCollectionOrScalar: true } + ); + expect(r).toBe(true); + }); + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts new file mode 100644 index 0000000000..092279e315 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -0,0 +1,69 @@ +import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import type { FieldType } from 'features/nodes/types/field'; + +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the connection is valid, false otherwise. + */ +export const validateConnectionTypes = (sourceType: FieldType, targetType: FieldType) => { + // TODO: There's a bug with Collect -> Iterate nodes: + // https://github.com/invoke-ai/InvokeAI/issues/3956 + // Once this is resolved, we can remove this check. + if (sourceType.name === 'CollectionField' && targetType.name === 'CollectionField') { + return false; + } + + if (areTypesEqual(sourceType, targetType)) { + return true; + } + + /** + * Connection types must be the same for a connection, with exceptions: + * - CollectionItem can connect to any non-Collection + * - Non-Collections can connect to CollectionItem + * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type + * - Generic Collection can connect to any other Collection or CollectionOrScalar + * - Any Collection can connect to a Generic Collection + */ + const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; + + const isNonCollectionToCollectionItem = + targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; + + const isAnythingToCollectionOrScalarOfSameBaseType = + targetType.isCollectionOrScalar && sourceType.name === targetType.name; + + const isGenericCollectionToAnyCollectionOrCollectionOrScalar = + sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); + + const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; + + const areBothTypesSingle = + !sourceType.isCollection && + !sourceType.isCollectionOrScalar && + !targetType.isCollection && + !targetType.isCollectionOrScalar; + + const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; + + const isIntOrFloatToString = + areBothTypesSingle && + (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && + targetType.name === 'StringField'; + + const isTargetAnyType = targetType.name === 'AnyField'; + + // One of these must be true for the connection to be valid + return ( + isCollectionItemToNonCollection || + isNonCollectionToCollectionItem || + isAnythingToCollectionOrScalarOfSameBaseType || + isGenericCollectionToAnyCollectionOrCollectionOrScalar || + isCollectionToGenericCollection || + isIntToFloat || + isIntOrFloatToString || + isTargetAnyType + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index a98f773c7e..8a1a0b5039 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -188,7 +188,6 @@ const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIntegerFieldType, - originalType: zFieldType.optional(), }); export type IntegerFieldValue = z.infer; export type IntegerFieldInputInstance = z.infer; @@ -217,7 +216,6 @@ const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zFloatFieldType, - originalType: zFieldType.optional(), }); export type FloatFieldValue = z.infer; export type FloatFieldInputInstance = z.infer; @@ -243,7 +241,6 @@ const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStringFieldType, - originalType: zFieldType.optional(), }); export type StringFieldValue = z.infer; @@ -268,7 +265,6 @@ const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBooleanFieldType, - originalType: zFieldType.optional(), }); export type BooleanFieldValue = z.infer; export type BooleanFieldInputInstance = z.infer; @@ -294,7 +290,6 @@ const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zEnumFieldType, - originalType: zFieldType.optional(), }); export type EnumFieldValue = z.infer; export type EnumFieldInputInstance = z.infer; @@ -318,7 +313,6 @@ const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zImageFieldType, - originalType: zFieldType.optional(), }); export type ImageFieldValue = z.infer; export type ImageFieldInputInstance = z.infer; @@ -342,7 +336,6 @@ const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zBoardFieldType, - originalType: zFieldType.optional(), }); export type BoardFieldValue = z.infer; export type BoardFieldInputInstance = z.infer; @@ -366,7 +359,6 @@ const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zColorFieldType, - originalType: zFieldType.optional(), }); export type ColorFieldValue = z.infer; export type ColorFieldInputInstance = z.infer; @@ -390,7 +382,6 @@ const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zMainModelFieldType, - originalType: zFieldType.optional(), }); export type MainModelFieldValue = z.infer; export type MainModelFieldInputInstance = z.infer; @@ -413,7 +404,6 @@ const zModelIdentifierFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zModelIdentifierFieldType, - originalType: zFieldType.optional(), }); export type ModelIdentifierFieldValue = z.infer; export type ModelIdentifierFieldInputInstance = z.infer; @@ -437,7 +427,6 @@ const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLMainModelFieldType, - originalType: zFieldType.optional(), }); export type SDXLMainModelFieldInputInstance = z.infer; export type SDXLMainModelFieldInputTemplate = z.infer; @@ -461,7 +450,6 @@ const zSDXLRefinerModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSDXLRefinerModelFieldType, - originalType: zFieldType.optional(), }); export type SDXLRefinerModelFieldValue = z.infer; export type SDXLRefinerModelFieldInputInstance = z.infer; @@ -485,7 +473,6 @@ const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zVAEModelFieldType, - originalType: zFieldType.optional(), }); export type VAEModelFieldValue = z.infer; export type VAEModelFieldInputInstance = z.infer; @@ -509,7 +496,6 @@ const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zLoRAModelFieldType, - originalType: zFieldType.optional(), }); export type LoRAModelFieldValue = z.infer; export type LoRAModelFieldInputInstance = z.infer; @@ -533,7 +519,6 @@ const zControlNetModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zControlNetModelFieldType, - originalType: zFieldType.optional(), }); export type ControlNetModelFieldValue = z.infer; export type ControlNetModelFieldInputInstance = z.infer; @@ -557,7 +542,6 @@ const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zIPAdapterModelFieldType, - originalType: zFieldType.optional(), }); export type IPAdapterModelFieldValue = z.infer; export type IPAdapterModelFieldInputInstance = z.infer; @@ -581,7 +565,6 @@ const zT2IAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zT2IAdapterModelFieldType, - originalType: zFieldType.optional(), }); export type T2IAdapterModelFieldValue = z.infer; export type T2IAdapterModelFieldInputInstance = z.infer; @@ -605,7 +588,6 @@ const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zSchedulerFieldType, - originalType: zFieldType.optional(), }); export type SchedulerFieldValue = z.infer; export type SchedulerFieldInputInstance = z.infer; @@ -641,7 +623,6 @@ const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ }); const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ type: zStatelessFieldType, - originalType: zFieldType.optional(), }); export type StatelessFieldInputTemplate = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts index 480387a8a4..656bdc9d64 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.test.ts @@ -1,942 +1,19 @@ +import { schema, templates } from 'features/nodes/store/util/testUtils'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { omit, pick } from 'lodash-es'; -import type { OpenAPIV3_1 } from 'openapi-types'; import { describe, expect, it } from 'vitest'; describe('parseSchema', () => { it('should parse the schema', () => { - const templates = parseSchema(schema); - expect(templates).toEqual(expected); + const parsed = parseSchema(schema); + expect(parsed).toEqual(templates); }); it('should omit denied nodes', () => { - const templates = parseSchema(schema, undefined, ['add']); - expect(templates).toEqual(omit(expected, 'add')); + const parsed = parseSchema(schema, undefined, ['add']); + expect(parsed).toEqual(omit(templates, 'add')); }); it('should include only allowed nodes', () => { - const templates = parseSchema(schema, ['add']); - expect(templates).toEqual(pick(expected, 'add')); + const parsed = parseSchema(schema, ['add']); + expect(parsed).toEqual(pick(templates, 'add')); }); }); - -const expected = { - add: { - title: 'Add Integers', - type: 'add', - version: '1.0.1', - tags: ['math', 'add'], - description: 'Adds two numbers', - outputType: 'integer_output', - inputs: { - a: { - name: 'a', - title: 'A', - required: false, - description: 'The first number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - b: { - name: 'b', - title: 'B', - required: false, - description: 'The second number', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - default: 0, - }, - }, - outputs: { - value: { - fieldKind: 'output', - name: 'value', - title: 'Value', - description: 'The output integer', - type: { - name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - scheduler: { - title: 'Scheduler', - type: 'scheduler', - version: '1.0.0', - tags: ['scheduler'], - description: 'Selects a scheduler.', - outputType: 'scheduler_output', - inputs: { - scheduler: { - name: 'scheduler', - title: 'Scheduler', - required: false, - description: 'Scheduler to use during inference', - fieldKind: 'input', - input: 'any', - ui_hidden: false, - ui_type: 'SchedulerField', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - default: 'euler', - }, - }, - outputs: { - scheduler: { - fieldKind: 'output', - name: 'scheduler', - title: 'Scheduler', - description: 'Scheduler to use during inference', - type: { - name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - main_model_loader: { - title: 'Main Model', - type: 'main_model_loader', - version: '1.0.2', - tags: ['model'], - description: 'Loads a main model, outputting its submodels.', - outputType: 'model_loader_output', - inputs: { - model: { - name: 'model', - title: 'Model', - required: true, - description: 'Main model (UNet, VAE, CLIP) to load', - fieldKind: 'input', - input: 'direct', - ui_hidden: false, - ui_type: 'MainModelField', - type: { - name: 'MainModelField', - isCollection: false, - isCollectionOrScalar: false, - originalType: { - name: 'ModelIdentifierField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - }, - }, - outputs: { - vae: { - fieldKind: 'output', - name: 'vae', - title: 'VAE', - description: 'VAE', - type: { - name: 'VAEField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - clip: { - fieldKind: 'output', - name: 'clip', - title: 'CLIP', - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - type: { - name: 'CLIPField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - unet: { - fieldKind: 'output', - name: 'unet', - title: 'UNet', - description: 'UNet (scheduler, LoRAs)', - type: { - name: 'UNetField', - isCollection: false, - isCollectionOrScalar: false, - }, - ui_hidden: false, - }, - }, - useCache: true, - nodePack: 'invokeai', - classification: 'stable', - }, - collect: { - title: 'Collect', - type: 'collect', - version: '1.0.0', - tags: [], - description: 'Collects values into a collection', - outputType: 'collect_output', - inputs: { - item: { - name: 'item', - title: 'Collection Item', - required: false, - description: 'The item to collect (all inputs must be of the same type)', - fieldKind: 'input', - input: 'connection', - ui_hidden: false, - ui_type: 'CollectionItemField', - type: { - name: 'CollectionItemField', - isCollection: false, - isCollectionOrScalar: false, - }, - }, - }, - outputs: { - collection: { - fieldKind: 'output', - name: 'collection', - title: 'Collection', - description: 'The collection of input items', - type: { - name: 'CollectionField', - isCollection: true, - isCollectionOrScalar: false, - }, - ui_hidden: false, - ui_type: 'CollectionField', - }, - }, - useCache: true, - classification: 'stable', - }, -}; - -const schema = { - openapi: '3.1.0', - info: { - title: 'Invoke - Community Edition', - description: 'An API for invoking AI image operations', - version: '1.0.0', - }, - components: { - schemas: { - AddInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - a: { - type: 'integer', - title: 'A', - description: 'The first number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - b: { - type: 'integer', - title: 'B', - description: 'The second number', - default: 0, - field_kind: 'input', - input: 'any', - orig_default: 0, - orig_required: false, - ui_hidden: false, - }, - type: { - type: 'string', - enum: ['add'], - const: 'add', - title: 'type', - default: 'add', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Add Integers', - description: 'Adds two numbers', - category: 'math', - classification: 'stable', - node_pack: 'invokeai', - tags: ['math', 'add'], - version: '1.0.1', - output: { - $ref: '#/components/schemas/IntegerOutput', - }, - class: 'invocation', - }, - IntegerOutput: { - description: 'Base class for nodes that output a single integer', - properties: { - value: { - description: 'The output integer', - field_kind: 'output', - title: 'Value', - type: 'integer', - ui_hidden: false, - }, - type: { - const: 'integer_output', - default: 'integer_output', - enum: ['integer_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['value', 'type', 'type'], - title: 'IntegerOutput', - type: 'object', - class: 'output', - }, - SchedulerInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - scheduler: { - type: 'string', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - title: 'Scheduler', - description: 'Scheduler to use during inference', - default: 'euler', - field_kind: 'input', - input: 'any', - orig_default: 'euler', - orig_required: false, - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - type: 'string', - enum: ['scheduler'], - const: 'scheduler', - title: 'type', - default: 'scheduler', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'Scheduler', - description: 'Selects a scheduler.', - category: 'latents', - classification: 'stable', - node_pack: 'invokeai', - tags: ['scheduler'], - version: '1.0.0', - output: { - $ref: '#/components/schemas/SchedulerOutput', - }, - class: 'invocation', - }, - SchedulerOutput: { - properties: { - scheduler: { - description: 'Scheduler to use during inference', - enum: [ - 'ddim', - 'ddpm', - 'deis', - 'lms', - 'lms_k', - 'pndm', - 'heun', - 'heun_k', - 'euler', - 'euler_k', - 'euler_a', - 'kdpm_2', - 'kdpm_2_a', - 'dpmpp_2s', - 'dpmpp_2s_k', - 'dpmpp_2m', - 'dpmpp_2m_k', - 'dpmpp_2m_sde', - 'dpmpp_2m_sde_k', - 'dpmpp_sde', - 'dpmpp_sde_k', - 'unipc', - 'lcm', - 'tcd', - ], - field_kind: 'output', - title: 'Scheduler', - type: 'string', - ui_hidden: false, - ui_type: 'SchedulerField', - }, - type: { - const: 'scheduler_output', - default: 'scheduler_output', - enum: ['scheduler_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['scheduler', 'type', 'type'], - title: 'SchedulerOutput', - type: 'object', - class: 'output', - }, - MainModelLoaderInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - model: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Main model (UNet, VAE, CLIP) to load', - field_kind: 'input', - input: 'direct', - orig_required: true, - ui_hidden: false, - ui_type: 'MainModelField', - }, - type: { - type: 'string', - enum: ['main_model_loader'], - const: 'main_model_loader', - title: 'type', - default: 'main_model_loader', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['model', 'type', 'id'], - title: 'Main Model', - description: 'Loads a main model, outputting its submodels.', - category: 'model', - classification: 'stable', - node_pack: 'invokeai', - tags: ['model'], - version: '1.0.2', - output: { - $ref: '#/components/schemas/ModelLoaderOutput', - }, - class: 'invocation', - }, - ModelIdentifierField: { - properties: { - key: { - description: "The model's unique key", - title: 'Key', - type: 'string', - }, - hash: { - description: "The model's BLAKE3 hash", - title: 'Hash', - type: 'string', - }, - name: { - description: "The model's name", - title: 'Name', - type: 'string', - }, - base: { - allOf: [ - { - $ref: '#/components/schemas/BaseModelType', - }, - ], - description: "The model's base model type", - }, - type: { - allOf: [ - { - $ref: '#/components/schemas/ModelType', - }, - ], - description: "The model's type", - }, - submodel_type: { - anyOf: [ - { - $ref: '#/components/schemas/SubModelType', - }, - { - type: 'null', - }, - ], - default: null, - description: 'The submodel to load, if this is a main model', - }, - }, - required: ['key', 'hash', 'name', 'base', 'type'], - title: 'ModelIdentifierField', - type: 'object', - }, - BaseModelType: { - description: 'Base model type.', - enum: ['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'], - title: 'BaseModelType', - type: 'string', - }, - ModelType: { - description: 'Model type.', - enum: ['onnx', 'main', 'vae', 'lora', 'controlnet', 'embedding', 'ip_adapter', 'clip_vision', 't2i_adapter'], - title: 'ModelType', - type: 'string', - }, - SubModelType: { - description: 'Submodel type.', - enum: [ - 'unet', - 'text_encoder', - 'text_encoder_2', - 'tokenizer', - 'tokenizer_2', - 'vae', - 'vae_decoder', - 'vae_encoder', - 'scheduler', - 'safety_checker', - ], - title: 'SubModelType', - type: 'string', - }, - ModelLoaderOutput: { - description: 'Model loader output', - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/VAEField', - }, - ], - description: 'VAE', - field_kind: 'output', - title: 'VAE', - ui_hidden: false, - }, - type: { - const: 'model_loader_output', - default: 'model_loader_output', - enum: ['model_loader_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - clip: { - allOf: [ - { - $ref: '#/components/schemas/CLIPField', - }, - ], - description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', - field_kind: 'output', - title: 'CLIP', - ui_hidden: false, - }, - unet: { - allOf: [ - { - $ref: '#/components/schemas/UNetField', - }, - ], - description: 'UNet (scheduler, LoRAs)', - field_kind: 'output', - title: 'UNet', - ui_hidden: false, - }, - }, - required: ['vae', 'type', 'clip', 'unet', 'type'], - title: 'ModelLoaderOutput', - type: 'object', - class: 'output', - }, - UNetField: { - properties: { - unet: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load unet submodel', - }, - scheduler: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load scheduler submodel', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - freeu_config: { - anyOf: [ - { - $ref: '#/components/schemas/FreeUConfig', - }, - { - type: 'null', - }, - ], - default: null, - description: 'FreeU configuration', - }, - }, - required: ['unet', 'scheduler', 'loras'], - title: 'UNetField', - type: 'object', - class: 'output', - }, - LoRAField: { - properties: { - lora: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load lora model', - }, - weight: { - description: 'Weight to apply to lora model', - title: 'Weight', - type: 'number', - }, - }, - required: ['lora', 'weight'], - title: 'LoRAField', - type: 'object', - class: 'output', - }, - FreeUConfig: { - description: - 'Configuration for the FreeU hyperparameters.\n- https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu\n- https://github.com/ChenyangSi/FreeU', - properties: { - s1: { - description: - 'Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S1', - type: 'number', - }, - s2: { - description: - 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.', - maximum: 3.0, - minimum: -1.0, - title: 'S2', - type: 'number', - }, - b1: { - description: 'Scaling factor for stage 1 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B1', - type: 'number', - }, - b2: { - description: 'Scaling factor for stage 2 to amplify the contributions of backbone features.', - maximum: 3.0, - minimum: -1.0, - title: 'B2', - type: 'number', - }, - }, - required: ['s1', 's2', 'b1', 'b2'], - title: 'FreeUConfig', - type: 'object', - class: 'output', - }, - VAEField: { - properties: { - vae: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load vae submodel', - }, - seamless_axes: { - description: 'Axes("x" and "y") to which apply seamless', - items: { - type: 'string', - }, - title: 'Seamless Axes', - type: 'array', - }, - }, - required: ['vae'], - title: 'VAEField', - type: 'object', - class: 'output', - }, - CLIPField: { - properties: { - tokenizer: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load tokenizer submodel', - }, - text_encoder: { - allOf: [ - { - $ref: '#/components/schemas/ModelIdentifierField', - }, - ], - description: 'Info to load text_encoder submodel', - }, - skipped_layers: { - description: 'Number of skipped layers in text_encoder', - title: 'Skipped Layers', - type: 'integer', - }, - loras: { - description: 'LoRAs to apply on model loading', - items: { - $ref: '#/components/schemas/LoRAField', - }, - title: 'Loras', - type: 'array', - }, - }, - required: ['tokenizer', 'text_encoder', 'skipped_layers', 'loras'], - title: 'CLIPField', - type: 'object', - class: 'output', - }, - CollectInvocation: { - properties: { - id: { - type: 'string', - title: 'Id', - description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', - field_kind: 'node_attribute', - }, - is_intermediate: { - type: 'boolean', - title: 'Is Intermediate', - description: 'Whether or not this is an intermediate invocation.', - default: false, - field_kind: 'node_attribute', - ui_type: 'IsIntermediate', - }, - use_cache: { - type: 'boolean', - title: 'Use Cache', - description: 'Whether or not to use the cache', - default: true, - field_kind: 'node_attribute', - }, - item: { - anyOf: [ - {}, - { - type: 'null', - }, - ], - title: 'Collection Item', - description: 'The item to collect (all inputs must be of the same type)', - field_kind: 'input', - input: 'connection', - orig_required: false, - ui_hidden: false, - ui_type: 'CollectionItemField', - }, - collection: { - items: {}, - type: 'array', - title: 'Collection', - description: 'The collection, will be provided on execution', - default: [], - field_kind: 'input', - input: 'any', - orig_default: [], - orig_required: false, - ui_hidden: true, - }, - type: { - type: 'string', - enum: ['collect'], - const: 'collect', - title: 'type', - default: 'collect', - field_kind: 'node_attribute', - }, - }, - type: 'object', - required: ['type', 'id'], - title: 'CollectInvocation', - description: 'Collects values into a collection', - classification: 'stable', - version: '1.0.0', - output: { - $ref: '#/components/schemas/CollectInvocationOutput', - }, - class: 'invocation', - }, - CollectInvocationOutput: { - properties: { - collection: { - description: 'The collection of input items', - field_kind: 'output', - items: {}, - title: 'Collection', - type: 'array', - ui_hidden: false, - ui_type: 'CollectionField', - }, - type: { - const: 'collect_output', - default: 'collect_output', - enum: ['collect_output'], - field_kind: 'node_attribute', - title: 'type', - type: 'string', - }, - }, - required: ['collection', 'type', 'type'], - title: 'CollectInvocationOutput', - type: 'object', - class: 'output', - }, - }, - }, -} as OpenAPIV3_1.Document; From 04a596179b3ffdd69997b8d66b42faac349b5f4c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:25:58 +1000 Subject: [PATCH 022/207] tests(ui): finish test cases for validateConnection --- .../features/nodes/store/util/testUtils.ts | 344 +++++++++++++++++- .../store/util/validateConnection.test.ts | 22 +- .../nodes/store/util/validateConnection.ts | 28 +- 3 files changed, 388 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index efde3336e2..b68ff8bef6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -298,7 +298,149 @@ export const main_model_loader: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', -} +}; + +export const img_resize: InvocationTemplate = { + title: 'Resize Image', + type: 'img_resize', + version: '1.2.2', + tags: ['image', 'resize'], + description: 'Resizes an image to specific dimensions', + outputType: 'image_output', + inputs: { + board: { + name: 'board', + title: 'Board', + required: false, + description: 'The board to save the image to', + fieldKind: 'input', + input: 'direct', + ui_hidden: false, + type: { + name: 'BoardField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + metadata: { + name: 'metadata', + title: 'Metadata', + required: false, + description: 'Optional metadata to be saved with the image', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + type: { + name: 'MetadataField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + image: { + name: 'image', + title: 'Image', + required: true, + description: 'The image to resize', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'ImageField', + isCollection: false, + isCollectionOrScalar: false, + }, + }, + width: { + name: 'width', + title: 'Width', + required: false, + description: 'The width to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 512, + exclusiveMinimum: 0, + }, + height: { + name: 'height', + title: 'Height', + required: false, + description: 'The height to resize to (px)', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + default: 512, + exclusiveMinimum: 0, + }, + resample_mode: { + name: 'resample_mode', + title: 'Resample Mode', + required: false, + description: 'The resampling mode', + fieldKind: 'input', + input: 'any', + ui_hidden: false, + type: { + name: 'EnumField', + isCollection: false, + isCollectionOrScalar: false, + }, + options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + default: 'bicubic', + }, + }, + outputs: { + image: { + fieldKind: 'output', + name: 'image', + title: 'Image', + description: 'The output image', + type: { + name: 'ImageField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + width: { + fieldKind: 'output', + name: 'width', + title: 'Width', + description: 'The width of the image in pixels', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + height: { + fieldKind: 'output', + name: 'height', + title: 'Height', + description: 'The height of the image in pixels', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + nodePack: 'invokeai', + classification: 'stable', +}; export const templates: Templates = { add, @@ -306,6 +448,7 @@ export const templates: Templates = { collect, scheduler, main_model_loader, + img_resize, }; export const schema = { @@ -1068,6 +1211,205 @@ export const schema = { }, class: 'invocation', }, + ImageResizeInvocation: { + properties: { + board: { + anyOf: [ + { + $ref: '#/components/schemas/BoardField', + }, + { + type: 'null', + }, + ], + description: 'The board to save the image to', + field_kind: 'internal', + input: 'direct', + orig_required: false, + ui_hidden: false, + }, + metadata: { + anyOf: [ + { + $ref: '#/components/schemas/MetadataField', + }, + { + type: 'null', + }, + ], + description: 'Optional metadata to be saved with the image', + field_kind: 'internal', + input: 'connection', + orig_required: false, + ui_hidden: false, + }, + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The image to resize', + field_kind: 'input', + input: 'any', + orig_required: true, + ui_hidden: false, + }, + width: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Width', + description: 'The width to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + height: { + type: 'integer', + exclusiveMinimum: 0, + title: 'Height', + description: 'The height to resize to (px)', + default: 512, + field_kind: 'input', + input: 'any', + orig_default: 512, + orig_required: false, + ui_hidden: false, + }, + resample_mode: { + type: 'string', + enum: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], + title: 'Resample Mode', + description: 'The resampling mode', + default: 'bicubic', + field_kind: 'input', + input: 'any', + orig_default: 'bicubic', + orig_required: false, + ui_hidden: false, + }, + type: { + type: 'string', + enum: ['img_resize'], + const: 'img_resize', + title: 'type', + default: 'img_resize', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'Resize Image', + description: 'Resizes an image to specific dimensions', + category: 'image', + classification: 'stable', + node_pack: 'invokeai', + tags: ['image', 'resize'], + version: '1.2.2', + output: { + $ref: '#/components/schemas/ImageOutput', + }, + class: 'invocation', + }, + ImageField: { + description: 'An image primitive field', + properties: { + image_name: { + description: 'The name of the image', + title: 'Image Name', + type: 'string', + }, + }, + required: ['image_name'], + title: 'ImageField', + type: 'object', + class: 'output', + }, + ImageOutput: { + description: 'Base class for nodes that output a single image', + properties: { + image: { + allOf: [ + { + $ref: '#/components/schemas/ImageField', + }, + ], + description: 'The output image', + field_kind: 'output', + ui_hidden: false, + }, + width: { + description: 'The width of the image in pixels', + field_kind: 'output', + title: 'Width', + type: 'integer', + ui_hidden: false, + }, + height: { + description: 'The height of the image in pixels', + field_kind: 'output', + title: 'Height', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'image_output', + default: 'image_output', + enum: ['image_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['image', 'width', 'height', 'type', 'type'], + title: 'ImageOutput', + type: 'object', + class: 'output', + }, + MetadataField: { + description: + 'Pydantic model for metadata with custom root of type dict[str, Any].\nMetadata is stored without a strict schema.', + title: 'MetadataField', + type: 'object', + class: 'output', + }, + BoardField: { + properties: { + board_id: { + type: 'string', + title: 'Board Id', + description: 'The id of the board', + }, + }, + type: 'object', + required: ['board_id'], + title: 'BoardField', + description: 'A board primitive field', + }, }, }, } as OpenAPIV3_1.Document; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 5d10ef368b..cf05b4deb6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,7 +3,7 @@ import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNod import { set } from 'lodash-es'; import { describe, expect, it } from 'vitest'; -import { add, buildEdge, collect, main_model_loader, position, sub, templates } from './testUtils'; +import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils'; import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; describe(validateConnection.name, () => { @@ -146,4 +146,24 @@ describe(validateConnection.name, () => { const r = validateConnection(c, nodes, edges, templates, e1); expect(r).toEqual(buildAcceptResult()); }); + + it('should reject connections between invalid types', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null); + expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch')); + }); + + it('should reject connections that would create cycles', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index d45a75ab9f..db8b7b737e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -1,6 +1,8 @@ import type { Templates } from 'features/nodes/store/types'; import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; +import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; +import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { O } from 'ts-toolbelt'; @@ -36,6 +38,12 @@ const getEqualityPredicate = ); }; +const getTargetEqualityPredicate = + (c: Connection) => + (e: Edge): boolean => { + return e.target === c.target && e.targetHandle === c.targetHandle; + }; + export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); @@ -44,6 +52,12 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp return buildRejectResult('nodes.cannotConnectToSelf'); } + /** + * We may need to ignore an edge when validating a connection. + * + * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it. + */ const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); if (filteredEdges.some(getEqualityPredicate(c))) { @@ -96,14 +110,20 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp } if ( - edges.find((e) => { - return e.target === c.target && e.targetHandle === c.targetHandle; - }) && - // except CollectionItem inputs can have multiples + filteredEdges.find(getTargetEqualityPredicate(c)) && + // except CollectionItem inputs can have multiple input connections targetFieldTemplate.type.name !== 'CollectionItemField' ) { return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); } + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return buildRejectResult('nodes.fieldTypesMustMatch'); + } + + if (getHasCycles(c.source, c.target, nodes, edges)) { + return buildRejectResult('nodes.connectionWouldCreateCycle'); + } + return buildAcceptResult(); }; From 00c2d8f95d6f966feb631942e66d3b09f387b203 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:30:26 +1000 Subject: [PATCH 023/207] tidy(ui): areTypesEqual var names --- .../nodes/store/util/areTypesEqual.ts | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts index e01b48b972..8502cb563c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.ts @@ -3,27 +3,26 @@ import { isEqual, omit } from 'lodash-es'; /** * Checks if two types are equal. If the field types have original types, those are also compared. Any match is - * considered equal. For example, if the source type and original target type match, the types are considered equal. - * @param sourceType The type of the source field. - * @param targetType The type of the target field. + * considered equal. For example, if the first type and original second type match, the types are considered equal. + * @param firstType The first type to compare. + * @param secondType The second type to compare. * @returns True if the types are equal, false otherwise. */ - -export const areTypesEqual = (sourceType: FieldType, targetType: FieldType) => { - const _sourceType = 'originalType' in sourceType ? omit(sourceType, 'originalType') : sourceType; - const _targetType = 'originalType' in targetType ? omit(targetType, 'originalType') : targetType; - const _sourceTypeOriginal = 'originalType' in sourceType ? sourceType.originalType : null; - const _targetTypeOriginal = 'originalType' in targetType ? targetType.originalType : null; - if (isEqual(_sourceType, _targetType)) { +export const areTypesEqual = (firstType: FieldType, secondType: FieldType) => { + const _firstType = 'originalType' in firstType ? omit(firstType, 'originalType') : firstType; + const _secondType = 'originalType' in secondType ? omit(secondType, 'originalType') : secondType; + const _originalFirstType = 'originalType' in firstType ? firstType.originalType : null; + const _originalSecondType = 'originalType' in secondType ? secondType.originalType : null; + if (isEqual(_firstType, _secondType)) { return true; } - if (_targetTypeOriginal && isEqual(_sourceType, _targetTypeOriginal)) { + if (_originalSecondType && isEqual(_firstType, _originalSecondType)) { return true; } - if (_sourceTypeOriginal && isEqual(_sourceTypeOriginal, _targetType)) { + if (_originalFirstType && isEqual(_originalFirstType, _secondType)) { return true; } - if (_sourceTypeOriginal && _targetTypeOriginal && isEqual(_sourceTypeOriginal, _targetTypeOriginal)) { + if (_originalFirstType && _originalSecondType && isEqual(_originalFirstType, _originalSecondType)) { return true; } return false; From 059d5a682c8d5556d1dc3634d2b5026901d0aeeb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:30:49 +1000 Subject: [PATCH 024/207] tidy(ui): validateConnection code clarity --- .../nodes/store/util/validateConnection.ts | 22 +++++++++---------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index db8b7b737e..debf294557 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -56,7 +56,8 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp * We may need to ignore an edge when validating a connection. * * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, - * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it. + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else + * the validation will fail unexpectedly. */ const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); @@ -100,21 +101,18 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp } if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { - // Collect nodes shouldn't mix and match field types + // Collect nodes shouldn't mix and match field types. const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceFieldTemplate.type, collectItemType)) { - return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); - } + if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); } } - if ( - filteredEdges.find(getTargetEqualityPredicate(c)) && - // except CollectionItem inputs can have multiple input connections - targetFieldTemplate.type.name !== 'CollectionItemField' - ) { - return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + // CollectionItemField inputs can have multiple input connections + if (targetFieldTemplate.type.name !== 'CollectionItemField') { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } } if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { From 8074a802d6549b32ed60d5ee973c50c397018dac Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:53:27 +1000 Subject: [PATCH 025/207] tests(ui): coverage for validateConnectionTypes --- .../util/validateConnectionTypes.test.ts | 2 +- .../store/util/validateConnectionTypes.ts | 30 +++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts index d953fd973f..10344dd349 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -175,7 +175,7 @@ describe(validateConnectionTypes.name, () => { it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => { const r = validateConnectionTypes( { name: t1, isCollection: true, isCollectionOrScalar: false }, - { name: t2, isCollection: false, isCollectionOrScalar: false } + { name: t2, isCollection: true, isCollectionOrScalar: false } ); expect(r).toBe(true); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts index 092279e315..778b33a7b1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -40,18 +40,25 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; - const areBothTypesSingle = - !sourceType.isCollection && - !sourceType.isCollectionOrScalar && - !targetType.isCollection && - !targetType.isCollectionOrScalar; + const isSourceScalar = !sourceType.isCollection && !sourceType.isCollectionOrScalar; + const isTargetScalar = !targetType.isCollection && !targetType.isCollectionOrScalar; + const isScalarToScalar = isSourceScalar && isTargetScalar; + const isScalarToCollectionOrScalar = isSourceScalar && targetType.isCollectionOrScalar; + const isCollectionToCollection = sourceType.isCollection && targetType.isCollection; + const isCollectionToCollectionOrScalar = sourceType.isCollection && targetType.isCollectionOrScalar; + const isCollectionOrScalarToCollectionOrScalar = sourceType.isCollectionOrScalar && targetType.isCollectionOrScalar; + const isPluralityMatch = + isScalarToScalar || + isCollectionToCollection || + isCollectionToCollectionOrScalar || + isCollectionOrScalarToCollectionOrScalar || + isScalarToCollectionOrScalar; - const isIntToFloat = areBothTypesSingle && sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; + const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; + const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField'; + const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField'; - const isIntOrFloatToString = - areBothTypesSingle && - (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && - targetType.name === 'StringField'; + const isSubTypeMatch = isPluralityMatch && (isIntToFloat || isIntToString || isFloatToString); const isTargetAnyType = targetType.name === 'AnyField'; @@ -62,8 +69,7 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field isAnythingToCollectionOrScalarOfSameBaseType || isGenericCollectionToAnyCollectionOrCollectionOrScalar || isCollectionToGenericCollection || - isIntToFloat || - isIntOrFloatToString || + isSubTypeMatch || isTargetAnyType ); }; From 857889d1faf0cc2e7fd5b51d34ff3ac46b47b18e Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 00:56:16 +1000 Subject: [PATCH 026/207] tests(ui): coverage for getCollectItemType --- .../features/nodes/store/util/getCollectItemType.test.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts index 93c63b6f41..7f0a96bf33 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -13,4 +13,10 @@ describe(getCollectItemType.name, () => { const result = getCollectItemType(templates, nodes, edges, n2.id); expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); }); + it('should return null if the collect node does not have any connections', () => { + const n1 = buildInvocationNode(position, collect); + const nodes = [n1]; + const result = getCollectItemType(templates, nodes, [], n1.id); + expect(result).toBeNull(); + }); }); From 972398d203ca0a7bc67477713419428d4dc148f8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:04:58 +1000 Subject: [PATCH 027/207] tests(ui): add iterate to test schema --- .../features/nodes/store/util/testUtils.ts | 173 ++++++++++++++++++ 1 file changed, 173 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index b68ff8bef6..470236a82e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -442,10 +442,78 @@ export const img_resize: InvocationTemplate = { classification: 'stable', }; +const iterate: InvocationTemplate = { + title: 'Iterate', + type: 'iterate', + version: '1.1.0', + tags: [], + description: 'Iterates over a list of items', + outputType: 'iterate_output', + inputs: { + collection: { + name: 'collection', + title: 'Collection', + required: false, + description: 'The list of items to iterate over', + fieldKind: 'input', + input: 'connection', + ui_hidden: false, + ui_type: 'CollectionField', + type: { + name: 'CollectionField', + isCollection: true, + isCollectionOrScalar: false, + }, + }, + }, + outputs: { + item: { + fieldKind: 'output', + name: 'item', + title: 'Collection Item', + description: 'The item being iterated over', + type: { + name: 'CollectionItemField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + index: { + fieldKind: 'output', + name: 'index', + title: 'Index', + description: 'The index of the item', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + total: { + fieldKind: 'output', + name: 'total', + title: 'Total', + description: 'The total number of items', + type: { + name: 'IntegerField', + isCollection: false, + isCollectionOrScalar: false, + }, + ui_hidden: false, + }, + }, + useCache: true, + classification: 'stable', +}; + export const templates: Templates = { add, sub, collect, + iterate, scheduler, main_model_loader, img_resize, @@ -1410,6 +1478,111 @@ export const schema = { title: 'BoardField', description: 'A board primitive field', }, + IterateInvocation: { + properties: { + id: { + type: 'string', + title: 'Id', + description: 'The id of this instance of an invocation. Must be unique among all instances of invocations.', + field_kind: 'node_attribute', + }, + is_intermediate: { + type: 'boolean', + title: 'Is Intermediate', + description: 'Whether or not this is an intermediate invocation.', + default: false, + field_kind: 'node_attribute', + ui_type: 'IsIntermediate', + }, + use_cache: { + type: 'boolean', + title: 'Use Cache', + description: 'Whether or not to use the cache', + default: true, + field_kind: 'node_attribute', + }, + collection: { + items: {}, + type: 'array', + title: 'Collection', + description: 'The list of items to iterate over', + default: [], + field_kind: 'input', + input: 'any', + orig_default: [], + orig_required: false, + ui_hidden: false, + ui_type: 'CollectionField', + }, + index: { + type: 'integer', + title: 'Index', + description: 'The index, will be provided on executed iterators', + default: 0, + field_kind: 'input', + input: 'any', + orig_default: 0, + orig_required: false, + ui_hidden: true, + }, + type: { + type: 'string', + enum: ['iterate'], + const: 'iterate', + title: 'type', + default: 'iterate', + field_kind: 'node_attribute', + }, + }, + type: 'object', + required: ['type', 'id'], + title: 'IterateInvocation', + description: 'Iterates over a list of items', + classification: 'stable', + version: '1.1.0', + output: { + $ref: '#/components/schemas/IterateInvocationOutput', + }, + class: 'invocation', + }, + IterateInvocationOutput: { + description: 'Used to connect iteration outputs. Will be expanded to a specific output.', + properties: { + item: { + description: 'The item being iterated over', + field_kind: 'output', + title: 'Collection Item', + ui_hidden: false, + ui_type: 'CollectionItemField', + }, + index: { + description: 'The index of the item', + field_kind: 'output', + title: 'Index', + type: 'integer', + ui_hidden: false, + }, + total: { + description: 'The total number of items', + field_kind: 'output', + title: 'Total', + type: 'integer', + ui_hidden: false, + }, + type: { + const: 'iterate_output', + default: 'iterate_output', + enum: ['iterate_output'], + field_kind: 'node_attribute', + title: 'type', + type: 'string', + }, + }, + required: ['item', 'index', 'total', 'type', 'type'], + title: 'IterateInvocationOutput', + type: 'object', + class: 'output', + }, }, }, } as OpenAPIV3_1.Document; From 78f9f3ee95bbdd3472bb0e517eaccb7278f8df6b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:10:49 +1000 Subject: [PATCH 028/207] feat(ui): better types for validateConnection --- .../nodes/store/util/validateConnection.ts | 32 ++++++++++++++----- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index debf294557..b6b5a43d37 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -6,13 +6,19 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne import type { AnyNode } from 'features/nodes/types/invocation'; import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { O } from 'ts-toolbelt'; +import { assert } from 'tsafe'; type Connection = O.NonNullable; -export type ValidateConnectionResult = { - isValid: boolean; - messageTKey?: string; -}; +export type ValidateConnectionResult = + | { + isValid: true; + messageTKey?: string; + } + | { + isValid: false; + messageTKey: string; + }; export type ValidateConnectionFunc = ( connection: Connection, @@ -22,10 +28,20 @@ export type ValidateConnectionFunc = ( ignoreEdge: Edge | null ) => ValidateConnectionResult; -export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => ({ - isValid, - messageTKey, -}); +export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => { + if (isValid) { + return { + isValid, + messageTKey, + }; + } else { + assert(messageTKey !== undefined); + return { + isValid, + messageTKey, + }; + } +}; const getEqualityPredicate = (c: Connection) => From 6ad01d824d689319232bd6654bf722962766cfa0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:26:43 +1000 Subject: [PATCH 029/207] feat(ui): add strict mode to validateConnection --- .../store/util/validateConnection.test.ts | 26 ++++ .../nodes/store/util/validateConnection.ts | 127 +++++++++--------- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index cf05b4deb6..108839a499 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -166,4 +166,30 @@ describe(validateConnection.name, () => { const r = validateConnection(c, nodes, edges, templates, null); expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); }); + + describe('non-strict mode', () => { + it('should reject connections from self to self in non-strict mode', () => { + const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; + const r = validateConnection(c, [], [], templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + }); + it('should reject connections that create cycles in non-strict mode', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, sub); + const nodes = [n1, n2]; + const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); + const edges = [e1]; + const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; + const r = validateConnection(c, nodes, edges, templates, null, false); + expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + }); + it('should otherwise allow invalid connections in non-strict mode', () => { + const n1 = buildInvocationNode(position, add); + const n2 = buildInvocationNode(position, img_resize); + const nodes = [n1, n2]; + const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; + const r = validateConnection(c, nodes, [], templates, null, false); + expect(r).toEqual(buildAcceptResult()); + }); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index b6b5a43d37..edb8ac5ecb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -25,7 +25,8 @@ export type ValidateConnectionFunc = ( nodes: AnyNode[], edges: Edge[], templates: Templates, - ignoreEdge: Edge | null + ignoreEdge: Edge | null, + strict?: boolean ) => ValidateConnectionResult; export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => { @@ -63,76 +64,78 @@ const getTargetEqualityPredicate = export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); -export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge) => { +export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { if (c.source === c.target) { return buildRejectResult('nodes.cannotConnectToSelf'); } - /** - * We may need to ignore an edge when validating a connection. - * - * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, - * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else - * the validation will fail unexpectedly. - */ - const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); + if (strict) { + /** + * We may need to ignore an edge when validating a connection. + * + * For example, while an edge is being updated, it still exists in the array of edges. As we validate the new connection, + * the user experience should be that the edge is temporarily removed from the graph, so we need to ignore it, else + * the validation will fail unexpectedly. + */ + const filteredEdges = edges.filter((e) => e.id !== ignoreEdge?.id); - if (filteredEdges.some(getEqualityPredicate(c))) { - // We already have a connection from this source to this target - return buildRejectResult('nodes.cannotDuplicateConnection'); - } - - const sourceNode = nodes.find((n) => n.id === c.source); - if (!sourceNode) { - return buildRejectResult('nodes.missingNode'); - } - - const targetNode = nodes.find((n) => n.id === c.target); - if (!targetNode) { - return buildRejectResult('nodes.missingNode'); - } - - const sourceTemplate = templates[sourceNode.data.type]; - if (!sourceTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); - } - - const targetTemplate = templates[targetNode.data.type]; - if (!targetTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); - } - - const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; - if (!sourceFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); - } - - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; - if (!targetFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); - } - - if (targetFieldTemplate.input === 'direct') { - return buildRejectResult('nodes.cannotConnectToDirectInput'); - } - - if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { - // Collect nodes shouldn't mix and match field types. - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { - return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + if (filteredEdges.some(getEqualityPredicate(c))) { + // We already have a connection from this source to this target + return buildRejectResult('nodes.cannotDuplicateConnection'); } - } - if (filteredEdges.find(getTargetEqualityPredicate(c))) { - // CollectionItemField inputs can have multiple input connections - if (targetFieldTemplate.type.name !== 'CollectionItemField') { - return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + const sourceNode = nodes.find((n) => n.id === c.source); + if (!sourceNode) { + return buildRejectResult('nodes.missingNode'); } - } - if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return buildRejectResult('nodes.fieldTypesMustMatch'); + const targetNode = nodes.find((n) => n.id === c.target); + if (!targetNode) { + return buildRejectResult('nodes.missingNode'); + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return buildRejectResult('nodes.missingInvocationTemplate'); + } + + const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; + if (!sourceFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; + if (!targetFieldTemplate) { + return buildRejectResult('nodes.missingFieldTemplate'); + } + + if (targetFieldTemplate.input === 'direct') { + return buildRejectResult('nodes.cannotConnectToDirectInput'); + } + + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { + // Collect nodes shouldn't mix and match field types. + const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { + return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + } + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + // CollectionItemField inputs can have multiple input connections + if (targetFieldTemplate.type.name !== 'CollectionItemField') { + return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + } + } + + if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { + return buildRejectResult('nodes.fieldTypesMustMatch'); + } } if (getHasCycles(c.source, c.target, nodes, edges)) { From fc31dddbf7c1deb1c10b621bc21281e5bead39d4 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:27:25 +1000 Subject: [PATCH 030/207] feat(ui): use new validateConnection --- .../nodes/hooks/useConnectionState.ts | 9 +- .../nodes/hooks/useIsValidConnection.ts | 84 ++--------- .../nodes/store/util/connectionValidation.ts | 134 ------------------ .../store/util/makeConnectionErrorSelector.ts | 72 ++++++++++ 4 files changed, 88 insertions(+), 211 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 9571ce2ee2..5dcb7a28b5 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -2,11 +2,9 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { makeConnectionErrorSelector } from 'features/nodes/store/util/connectionValidation.js'; +import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; -import { useFieldType } from './useFieldType.ts'; - type UseConnectionStateProps = { nodeId: string; fieldName: string; @@ -16,7 +14,6 @@ type UseConnectionStateProps = { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); - const fieldType = useFieldType(nodeId, fieldName, kind); const selectIsConnected = useMemo( () => @@ -34,8 +31,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta ); const selectConnectionError = useMemo( - () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source', fieldType), - [templates, nodeId, fieldName, kind, fieldType] + () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'), + [templates, nodeId, fieldName, kind] ); const isConnected = useAppSelector(selectIsConnected); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 77c4e3c75b..0f8609d2ff 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -2,13 +2,9 @@ import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { $templates } from 'features/nodes/store/nodesSlice'; -import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; -import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; -import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { InvocationNodeData } from 'features/nodes/types/invocation'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { useCallback } from 'react'; -import type { Connection, Node } from 'reactflow'; +import type { Connection } from 'reactflow'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -26,74 +22,20 @@ export const useIsValidConnection = () => { return false; } - if (source === target) { - // Don't allow nodes to connect to themselves, even if validation is disabled - return false; - } + const { nodes, edges } = store.getState().nodes.present; - const state = store.getState(); - const { nodes, edges } = state.nodes.present; + const validationResult = validateConnection( + { source, sourceHandle, target, targetHandle }, + nodes, + edges, + templates, + null, + shouldValidateGraph + ); - // Find the source and target nodes - const sourceNode = nodes.find((node) => node.id === source) as Node; - const targetNode = nodes.find((node) => node.id === target) as Node; - const sourceFieldTemplate = templates[sourceNode.data.type]?.outputs[sourceHandle]; - const targetFieldTemplate = templates[targetNode.data.type]?.inputs[targetHandle]; - - // Conditional guards against undefined nodes/handles - if (!(sourceFieldTemplate && targetFieldTemplate)) { - return false; - } - - if (targetFieldTemplate.input === 'direct') { - return false; - } - - if (!shouldValidateGraph) { - // manual override! - return true; - } - - if ( - edges.find((edge) => { - edge.target === target && - edge.targetHandle === targetHandle && - edge.source === source && - edge.sourceHandle === sourceHandle; - }) - ) { - // We already have a connection from this source to this target - return false; - } - - if (targetNode.data.type === 'collect' && targetFieldTemplate.name === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - return areTypesEqual(sourceFieldTemplate.type, collectItemType); - } - } - - // Connection is invalid if target already has a connection - if ( - edges.find((edge) => { - return edge.target === target && edge.targetHandle === targetHandle; - }) && - // except CollectionItem inputs can have multiples - targetFieldTemplate.type.name !== 'CollectionItemField' - ) { - return false; - } - - // Must use the originalType here if it exists - if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return false; - } - - // Graphs much be acyclic (no loops!) - return !getHasCycles(source, target, nodes, edges); + return validationResult.isValid; }, - [shouldValidateGraph, templates, store] + [templates, shouldValidateGraph, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts deleted file mode 100644 index 7819221f8a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/store/util/connectionValidation.ts +++ /dev/null @@ -1,134 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import type { RootState } from 'app/store/store'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { FieldType } from 'features/nodes/types/field'; -import i18n from 'i18next'; -import type { HandleType } from 'reactflow'; -import { assert } from 'tsafe'; - -import { areTypesEqual } from './areTypesEqual'; -import { getCollectItemType } from './getCollectItemType'; -import { getHasCycles } from './getHasCycles'; - -/** - * Creates a selector that validates a pending connection. - * - * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` - * TODO: Figure out how to do this without duplicating all the logic - * - * @param templates The invocation templates - * @param pendingConnection The current pending connection (if there is one) - * @param nodeId The id of the node for which the selector is being created - * @param fieldName The name of the field for which the selector is being created - * @param handleType The type of the handle for which the selector is being created - * @param fieldType The type of the field for which the selector is being created - * @returns - */ -export const makeConnectionErrorSelector = ( - templates: Templates, - nodeId: string, - fieldName: string, - handleType: HandleType, - fieldType: FieldType -) => { - return createMemoizedSelector( - selectNodesSlice, - (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, - (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { - const { nodes, edges } = nodesSlice; - - if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); - } - - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - const connectionStartFieldType = pendingConnection.fieldTemplate.type; - - if (!connectionHandleType || !connectionNodeId || !connectionFieldName) { - return i18n.t('nodes.noConnectionData'); - } - - const targetType = handleType === 'target' ? fieldType : connectionStartFieldType; - const sourceType = handleType === 'source' ? fieldType : connectionStartFieldType; - - if (nodeId === connectionNodeId) { - return i18n.t('nodes.cannotConnectToSelf'); - } - - if (handleType === connectionHandleType) { - if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); - } - return i18n.t('nodes.cannotConnectInputToInput'); - } - - // we have to figure out which is the target and which is the source - const targetNodeId = handleType === 'target' ? nodeId : connectionNodeId; - const targetFieldName = handleType === 'target' ? fieldName : connectionFieldName; - const sourceNodeId = handleType === 'source' ? nodeId : connectionNodeId; - const sourceFieldName = handleType === 'source' ? fieldName : connectionFieldName; - - if ( - edges.find((edge) => { - edge.target === targetNodeId && - edge.targetHandle === targetFieldName && - edge.source === sourceNodeId && - edge.sourceHandle === sourceFieldName; - }) - ) { - // We already have a connection from this source to this target - return i18n.t('nodes.cannotDuplicateConnection'); - } - - const targetNode = nodes.find((node) => node.id === targetNodeId); - assert(targetNode, `Target node not found: ${targetNodeId}`); - const targetTemplate = templates[targetNode.data.type]; - assert(targetTemplate, `Target template not found: ${targetNode.data.type}`); - - if (targetTemplate.inputs[targetFieldName]?.input === 'direct') { - return i18n.t('nodes.cannotConnectToDirectInput'); - } - - if (targetNode.data.type === 'collect' && targetFieldName === 'item') { - // Collect nodes shouldn't mix and match field types - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); - if (collectItemType) { - if (!areTypesEqual(sourceType, collectItemType)) { - return i18n.t('nodes.cannotMixAndMatchCollectionItemTypes'); - } - } - } - - if ( - edges.find((edge) => { - return edge.target === targetNodeId && edge.targetHandle === targetFieldName; - }) && - // except CollectionItem inputs can have multiples - targetType.name !== 'CollectionItemField' - ) { - return i18n.t('nodes.inputMayOnlyHaveOneConnection'); - } - - if (!validateConnectionTypes(sourceType, targetType)) { - return i18n.t('nodes.fieldTypesMustMatch'); - } - - const hasCycles = getHasCycles( - connectionHandleType === 'source' ? connectionNodeId : nodeId, - connectionHandleType === 'source' ? nodeId : connectionNodeId, - nodes, - edges - ); - - if (hasCycles) { - return i18n.t('nodes.connectionWouldCreateCycle'); - } - - return; - } - ); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts new file mode 100644 index 0000000000..3cefb6815f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -0,0 +1,72 @@ +import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import type { RootState } from 'app/store/store'; +import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; +import i18n from 'i18next'; +import type { HandleType } from 'reactflow'; + +/** + * Creates a selector that validates a pending connection. + * + * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts` + * TODO: Figure out how to do this without duplicating all the logic + * + * @param templates The invocation templates + * @param nodeId The id of the node for which the selector is being created + * @param fieldName The name of the field for which the selector is being created + * @param handleType The type of the handle for which the selector is being created + * @returns + */ +export const makeConnectionErrorSelector = ( + templates: Templates, + nodeId: string, + fieldName: string, + handleType: HandleType +) => { + return createMemoizedSelector( + selectNodesSlice, + (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + const { nodes, edges } = nodesSlice; + + if (!pendingConnection) { + return i18n.t('nodes.noConnectionInProgress'); + } + + const connectionNodeId = pendingConnection.node.id; + const connectionFieldName = pendingConnection.fieldTemplate.name; + const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + + if (handleType === connectionHandleType) { + if (handleType === 'source') { + return i18n.t('nodes.cannotConnectOutputToOutput'); + } + return i18n.t('nodes.cannotConnectInputToInput'); + } + + // we have to figure out which is the target and which is the source + const source = handleType === 'source' ? nodeId : connectionNodeId; + const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName; + const target = handleType === 'target' ? nodeId : connectionNodeId; + const targetHandle = handleType === 'target' ? fieldName : connectionFieldName; + + const validationResult = validateConnection( + { + source, + sourceHandle, + target, + targetHandle, + }, + nodes, + edges, + templates, + null + ); + + if (!validationResult.isValid) { + return i18n.t(validationResult.messageTKey); + } + } + ); +}; From 3605b6b1a36c27ef2299de6892ae0aff424a3c68 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 01:37:54 +1000 Subject: [PATCH 031/207] fix(ui): handling for in-progress edge updates during conection validation --- .../web/src/features/nodes/components/flow/Flow.tsx | 10 +++++----- .../web/src/features/nodes/hooks/useConnection.ts | 9 +++++---- .../web/src/features/nodes/hooks/useConnectionState.ts | 5 +++-- .../src/features/nodes/hooks/useIsValidConnection.ts | 6 +++--- .../web/src/features/nodes/store/nodesSlice.ts | 2 +- .../nodes/store/util/getFirstValidConnection.ts | 10 ++++++---- .../nodes/store/util/makeConnectionErrorSelector.ts | 8 +++++--- 7 files changed, 28 insertions(+), 22 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 501513919a..18bbac0b44 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -9,8 +9,8 @@ import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, $didUpdateEdge, + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, @@ -160,8 +160,8 @@ export const Flow = memo(() => { * where the edge is deleted if you click it accidentally). */ - const onEdgeUpdateStart: NonNullable = useCallback((e, _edge, _handleType) => { - $isUpdatingEdge.set(true); + const onEdgeUpdateStart: NonNullable = useCallback((e, edge, _handleType) => { + $edgePendingUpdate.set(edge); $didUpdateEdge.set(false); $lastEdgeUpdateMouseEvent.set(e); }, []); @@ -196,7 +196,7 @@ export const Flow = memo(() => { dispatch(edgeDeleted(edge.id)); } - $isUpdatingEdge.set(false); + $edgePendingUpdate.set(null); $didUpdateEdge.set(false); $pendingConnection.set(null); $lastEdgeUpdateMouseEvent.set(null); @@ -259,7 +259,7 @@ export const Flow = memo(() => { useHotkeys(['meta+shift+z', 'ctrl+shift+z'], onRedoHotkey); const onEscapeHotkey = useCallback(() => { - if (!$isUpdatingEdge.get()) { + if (!$edgePendingUpdate.get()) { $pendingConnection.set(null); $isAddNodePopoverOpen.set(false); cancelConnection(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index 0190a0b29e..d81a9e5807 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -2,8 +2,8 @@ import { useStore } from '@nanostores/react'; import { useAppStore } from 'app/store/storeHooks'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { + $edgePendingUpdate, $isAddNodePopoverOpen, - $isUpdatingEdge, $pendingConnection, $templates, connectionMade, @@ -52,12 +52,12 @@ export const useConnection = () => { const onConnectEnd = useCallback(() => { const { dispatch } = store; const pendingConnection = $pendingConnection.get(); - const isUpdatingEdge = $isUpdatingEdge.get(); + const edgePendingUpdate = $edgePendingUpdate.get(); const mouseOverNodeId = $mouseOverNode.get(); // If we are in the middle of an edge update, and the mouse isn't over a node, we should just bail so the edge // update logic can finish up - if (isUpdatingEdge && !mouseOverNodeId) { + if (edgePendingUpdate && !mouseOverNodeId) { $pendingConnection.set(null); return; } @@ -80,7 +80,8 @@ export const useConnection = () => { edges, pendingConnection, candidateNode, - candidateTemplate + candidateTemplate, + edgePendingUpdate ); if (connection) { dispatch(connectionMade(connection)); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 5dcb7a28b5..7649209863 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -1,7 +1,7 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $edgePendingUpdate, $pendingConnection, $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; @@ -14,6 +14,7 @@ type UseConnectionStateProps = { export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionStateProps) => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); + const edgePendingUpdate = useStore($edgePendingUpdate); const selectIsConnected = useMemo( () => @@ -47,7 +48,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection)); + const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate)); const shouldDim = useMemo( () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 0f8609d2ff..9a978b09a8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,7 +1,7 @@ // TODO: enable this at some point import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; -import { $templates } from 'features/nodes/store/nodesSlice'; +import { $edgePendingUpdate, $templates } from 'features/nodes/store/nodesSlice'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { useCallback } from 'react'; import type { Connection } from 'reactflow'; @@ -21,7 +21,7 @@ export const useIsValidConnection = () => { if (!(source && sourceHandle && target && targetHandle)) { return false; } - + const edgePendingUpdate = $edgePendingUpdate.get(); const { nodes, edges } = store.getState().nodes.present; const validationResult = validateConnection( @@ -29,7 +29,7 @@ export const useIsValidConnection = () => { nodes, edges, templates, - null, + edgePendingUpdate, shouldValidateGraph ); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 83632c16e1..7915d3608c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -503,7 +503,7 @@ export const $copiedNodes = atom([]); export const $copiedEdges = atom([]); export const $edgesToCopiedNodes = atom([]); export const $pendingConnection = atom(null); -export const $isUpdatingEdge = atom(false); +export const $edgePendingUpdate = atom(null); export const $didUpdateEdge = atom(false); export const $lastEdgeUpdateMouseEvent = atom(null); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 98155f0c20..00899c065d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -2,7 +2,7 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; import { differenceWith, map } from 'lodash-es'; -import type { Connection } from 'reactflow'; +import type { Connection, Edge } from 'reactflow'; import { assert } from 'tsafe'; import { areTypesEqual } from './areTypesEqual'; @@ -26,7 +26,8 @@ export const getFirstValidConnection = ( edges: InvocationNodeEdge[], pendingConnection: PendingConnection, candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate + candidateTemplate: InvocationTemplate, + edgePendingUpdate: Edge | null ): Connection | null => { if (pendingConnection.node.id === candidateNode.id) { // Cannot connect to self @@ -52,7 +53,7 @@ export const getFirstValidConnection = ( // Only one connection per target field is allowed - look for an unconnected target field const candidateFields = map(candidateTemplate.inputs); const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id) + .filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id) .map((edge) => { // Edges must always have a targetHandle, safe to assert here assert(edge.targetHandle); @@ -63,7 +64,8 @@ export const getFirstValidConnection = ( candidateConnectedFields, (field, connectedFieldName) => field.name === connectedFieldName ); - const candidateField = candidateUnconnectedFields.find((field) => validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) + const candidateField = candidateUnconnectedFields.find((field) => + validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) ); if (candidateField) { return { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index 3cefb6815f..fb7ed49d41 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -4,7 +4,7 @@ import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import i18n from 'i18next'; -import type { HandleType } from 'reactflow'; +import type { Edge, HandleType } from 'reactflow'; /** * Creates a selector that validates a pending connection. @@ -27,7 +27,9 @@ export const makeConnectionErrorSelector = ( return createMemoizedSelector( selectNodesSlice, (state: RootState, pendingConnection: PendingConnection | null) => pendingConnection, - (nodesSlice: NodesState, pendingConnection: PendingConnection | null) => { + (state: RootState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => + edgePendingUpdate, + (nodesSlice: NodesState, pendingConnection: PendingConnection | null, edgePendingUpdate: Edge | null) => { const { nodes, edges } = nodesSlice; if (!pendingConnection) { @@ -61,7 +63,7 @@ export const makeConnectionErrorSelector = ( nodes, edges, templates, - null + edgePendingUpdate ); if (!validationResult.isValid) { From ea97ae5ae8ee31b4cb0718b63db10b61c73aef76 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 07:45:53 +1000 Subject: [PATCH 032/207] tidy(ui): extraneous vars in makeConnectionErrorSelector --- .../nodes/store/util/makeConnectionErrorSelector.ts | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index fb7ed49d41..e1a443a60e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -36,8 +36,6 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.noConnectionInProgress'); } - const connectionNodeId = pendingConnection.node.id; - const connectionFieldName = pendingConnection.fieldTemplate.name; const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; if (handleType === connectionHandleType) { @@ -48,10 +46,10 @@ export const makeConnectionErrorSelector = ( } // we have to figure out which is the target and which is the source - const source = handleType === 'source' ? nodeId : connectionNodeId; - const sourceHandle = handleType === 'source' ? fieldName : connectionFieldName; - const target = handleType === 'target' ? nodeId : connectionNodeId; - const targetHandle = handleType === 'target' ? fieldName : connectionFieldName; + const source = handleType === 'source' ? nodeId : pendingConnection.node.id; + const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name; + const target = handleType === 'target' ? nodeId : pendingConnection.node.id; + const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name; const validationResult = validateConnection( { From fe3980a3698e315f82b3929087ba093c8c411489 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 07:49:03 +1000 Subject: [PATCH 033/207] tests(ui): add buildNode convenience wrapper for buildInvocationNode --- .../store/util/getCollectItemType.test.ts | 9 ++- .../nodes/store/util/getHasCycles.test.ts | 9 ++- .../features/nodes/store/util/testUtils.ts | 3 + .../store/util/validateConnection.test.ts | 63 +++++++++---------- 4 files changed, 42 insertions(+), 42 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts index 7f0a96bf33..2fed41c14c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -1,20 +1,19 @@ import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; -import { add, buildEdge, collect, position, templates } from 'features/nodes/store/util/testUtils'; +import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils'; import type { FieldType } from 'features/nodes/types/field'; -import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { describe, expect, it } from 'vitest'; describe(getCollectItemType.name, () => { it('should return the type of the items the collect node collects', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, collect); + const n1 = buildNode(add); + const n2 = buildNode(collect); const nodes = [n1, n2]; const edges = [buildEdge(n1.id, 'value', n2.id, 'item')]; const result = getCollectItemType(templates, nodes, edges, n2.id); expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); }); it('should return null if the collect node does not have any connections', () => { - const n1 = buildInvocationNode(position, collect); + const n1 = buildNode(collect); const nodes = [n1]; const result = getCollectItemType(templates, nodes, [], n1.id); expect(result).toBeNull(); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts index 872da36998..5b3a31de09 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getHasCycles.test.ts @@ -1,12 +1,11 @@ import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; -import { add, buildEdge, position } from 'features/nodes/store/util/testUtils'; -import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; +import { add, buildEdge, buildNode } from 'features/nodes/store/util/testUtils'; import { describe, expect, it } from 'vitest'; describe(getHasCycles.name, () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, add); - const n3 = buildInvocationNode(position, add); + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); const nodes = [n1, n2, n3]; it('should return true if the graph WOULD have cycles after adding the edge', () => { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 470236a82e..f351083bc5 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -1,5 +1,6 @@ import type { Templates } from 'features/nodes/store/types'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; +import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import type { OpenAPIV3_1 } from 'openapi-types'; import type { Edge, XYPosition } from 'reactflow'; @@ -14,6 +15,8 @@ export const buildEdge = (source: string, sourceHandle: string, target: string, export const position: XYPosition = { x: 0, y: 0 }; +export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template); + export const add: InvocationTemplate = { title: 'Add Integers', type: 'add', diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 108839a499..19035afd54 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -1,9 +1,8 @@ import { deepClone } from 'common/util/deepClone'; -import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { set } from 'lodash-es'; import { describe, expect, it } from 'vitest'; -import { add, buildEdge, collect, img_resize, main_model_loader, position, sub, templates } from './testUtils'; +import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; describe(validateConnection.name, () => { @@ -14,8 +13,8 @@ describe(validateConnection.name, () => { }); describe('missing nodes', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; it('should reject missing source node', () => { @@ -30,8 +29,8 @@ describe(validateConnection.name, () => { }); describe('missing invocation templates', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const nodes = [n1, n2]; @@ -47,8 +46,8 @@ describe(validateConnection.name, () => { }); describe('missing field templates', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); const nodes = [n1, n2]; it('should reject missing source field template', () => { @@ -65,8 +64,8 @@ describe(validateConnection.name, () => { }); describe('duplicate connections', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); it('should accept non-duplicate connections', () => { const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, [n1, n2], [], templates, null); @@ -92,17 +91,17 @@ describe(validateConnection.name, () => { set(addWithDirectAField, 'inputs.a.input', 'direct'); set(addWithDirectAField, 'type', 'addWithDirectAField'); - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, addWithDirectAField); + const n1 = buildNode(add); + const n2 = buildNode(addWithDirectAField); const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null); expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput')); }); it('should reject connection to a collect node with mismatched item types', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, collect); - const n3 = buildInvocationNode(position, main_model_loader); + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(main_model_loader); const nodes = [n1, n2, n3]; const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); const edges = [e1]; @@ -112,9 +111,9 @@ describe(validateConnection.name, () => { }); it('should accept connection to a collect node with matching item types', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, collect); - const n3 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(sub); const nodes = [n1, n2, n3]; const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); const edges = [e1]; @@ -124,9 +123,9 @@ describe(validateConnection.name, () => { }); it('should reject connections to target field that is already connected', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, add); - const n3 = buildInvocationNode(position, add); + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); const nodes = [n1, n2, n3]; const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); const edges = [e1]; @@ -136,9 +135,9 @@ describe(validateConnection.name, () => { }); it('should accept connections to target field that is already connected (ignored edge)', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, add); - const n3 = buildInvocationNode(position, add); + const n1 = buildNode(add); + const n2 = buildNode(add); + const n3 = buildNode(add); const nodes = [n1, n2, n3]; const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); const edges = [e1]; @@ -148,8 +147,8 @@ describe(validateConnection.name, () => { }); it('should reject connections between invalid types', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, img_resize); + const n1 = buildNode(add); + const n2 = buildNode(img_resize); const nodes = [n1, n2]; const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; const r = validateConnection(c, nodes, [], templates, null); @@ -157,8 +156,8 @@ describe(validateConnection.name, () => { }); it('should reject connections that would create cycles', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); const nodes = [n1, n2]; const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); const edges = [e1]; @@ -174,8 +173,8 @@ describe(validateConnection.name, () => { expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); }); it('should reject connections that create cycles in non-strict mode', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, sub); + const n1 = buildNode(add); + const n2 = buildNode(sub); const nodes = [n1, n2]; const e1 = buildEdge(n1.id, 'value', n2.id, 'a'); const edges = [e1]; @@ -184,8 +183,8 @@ describe(validateConnection.name, () => { expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); }); it('should otherwise allow invalid connections in non-strict mode', () => { - const n1 = buildInvocationNode(position, add); - const n2 = buildInvocationNode(position, img_resize); + const n1 = buildNode(add); + const n2 = buildNode(img_resize); const nodes = [n1, n2]; const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; const r = validateConnection(c, nodes, [], templates, null, false); From ce2ad5903c8878d80f400028d8e3f69690eced98 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 09:09:42 +1000 Subject: [PATCH 034/207] feat(ui): extract logic for finding candidate fields to own function --- .../store/util/getFirstValidConnection.ts | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 00899c065d..1449a3298a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -116,3 +116,75 @@ export const getFirstValidConnection = ( return null; }; + +export const getTargetCandidateFields = ( + source: string, + sourceHandle: string, + target: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldInputTemplate[] => { + const sourceNode = nodes.find((n) => n.id === source); + const targetNode = nodes.find((n) => n.id === target); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } + + const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { + const c = { source, sourceHandle, target, targetHandle: field.name }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return targetCandidateFields; +}; + +export const getSourceCandidateFields = ( + target: string, + targetHandle: string, + source: string, + nodes: AnyNode[], + edges: Edge[], + templates: Templates, + edgePendingUpdate: Edge | null +): FieldOutputTemplate[] => { + const targetNode = nodes.find((n) => n.id === target); + const sourceNode = nodes.find((n) => n.id === source); + if (!sourceNode || !targetNode) { + return []; + } + + const sourceTemplate = templates[sourceNode.data.type]; + const targetTemplate = templates[targetNode.data.type]; + if (!sourceTemplate || !targetTemplate) { + return []; + } + + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + + const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => { + const c = { source, sourceHandle: field.name, target, targetHandle }; + const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return r.isValid; + }); + + return sourceCandidateFields; +}; From c98205d0d745e69108488469fb8d6e20df3710c9 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 09:10:12 +1000 Subject: [PATCH 035/207] tests(ui): candidate fields, getFirstValidConnection (wip) --- .../util/getFirstValidConnection.test.ts | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts new file mode 100644 index 0000000000..59d5bbadf6 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -0,0 +1,148 @@ +import { deepClone } from 'common/util/deepClone'; +import type { PendingConnection } from 'features/nodes/store/types'; +import { + getFirstValidConnection, + getSourceCandidateFields, + getTargetCandidateFields, +} from 'features/nodes/store/util/getFirstValidConnection'; +import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils'; +import { unset } from 'lodash-es'; +import { describe, expect, it } from 'vitest'; + +describe('getFirstValidConnection', () => { + it('should return null if the pending and candidate nodes are the same node', () => { + const pc: PendingConnection = { node: buildNode(add), template: add, fieldTemplate: add.inputs['a']! }; + const candidateNode = pc.node; + expect(getFirstValidConnection(templates, [pc.node], [], pc, candidateNode, add, null)).toBe(null); + }); + + describe('connecting from a source to a target', () => { + const pc: PendingConnection = { + node: buildNode(img_resize), + template: img_resize, + fieldTemplate: img_resize.outputs['width']!, + }; + const candidateNode = buildNode(img_resize); + + it('should return the first valid connection if there are no connected fields', () => { + const r = getFirstValidConnection(templates, [pc.node, candidateNode], [], pc, candidateNode, img_resize, null); + const c = { + source: pc.node.id, + sourceHandle: pc.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is a connected field', () => { + const r = getFirstValidConnection( + templates, + [pc.node, candidateNode], + [buildEdge(pc.node.id, 'width', candidateNode.id, 'width')], + pc, + candidateNode, + img_resize, + null + ); + const c = { + source: pc.node.id, + sourceHandle: pc.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'height', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is an edgePendingUpdate', () => { + const e = buildEdge(pc.node.id, 'width', candidateNode.id, 'width'); + const r = getFirstValidConnection(templates, [pc.node, candidateNode], [e], pc, candidateNode, img_resize, e); + const c = { + source: pc.node.id, + sourceHandle: pc.fieldTemplate.name, + target: candidateNode.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + }); + describe('connecting from a target to a source', () => {}); +}); + +describe('getTargetCandidateFields', () => { + it('should return an empty array if the nodes canot be found', () => { + const r = getTargetCandidateFields('missing', 'value', 'missing', [], [], templates, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the templates cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], {}, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the source field template cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + + const addWithoutOutputValue = deepClone(add); + unset(addWithoutOutputValue, 'outputs.value'); + + const r = getTargetCandidateFields(n1.id, 'value', n2.id, nodes, [], { add: addWithoutOutputValue }, null); + expect(r).toEqual([]); + }); + it('should return all valid target fields if there are no connected fields', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, null); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); + it('should ignore the edgePendingUpdate if provided', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); +}); + +describe('getSourceCandidateFields', () => { + it('should return an empty array if the nodes canot be found', () => { + const r = getSourceCandidateFields('missing', 'value', 'missing', [], [], templates, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the templates cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + const r = getSourceCandidateFields(n2.id, 'a', n1.id, nodes, [], {}, null); + expect(r).toEqual([]); + }); + it('should return an empty array if the source field template cannot be found', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + const nodes = [n1, n2]; + + const addWithoutInputA = deepClone(add); + unset(addWithoutInputA, 'inputs.a'); + + const r = getSourceCandidateFields(n1.id, 'a', n2.id, nodes, [], { add: addWithoutInputA }, null); + expect(r).toEqual([]); + }); + it('should return all valid source fields if there are no connected fields', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, null); + expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); + }); + it('should ignore the edgePendingUpdate if provided', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + const nodes = [n1, n2]; + const edgePendingUpdate = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate); + expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); + }); +}); From 83000a4190ac180f74679e49888018f8a61b732a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 09:59:29 +1000 Subject: [PATCH 036/207] feat(ui): rework getFirstValidConnection with new helpers --- .../store/util/getFirstValidConnection.ts | 143 +++++++----------- 1 file changed, 51 insertions(+), 92 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 1449a3298a..adc51341d7 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -1,117 +1,76 @@ -import type { PendingConnection, Templates } from 'features/nodes/store/types'; -import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { differenceWith, map } from 'lodash-es'; +import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; +import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation'; +import { map } from 'lodash-es'; import type { Connection, Edge } from 'reactflow'; -import { assert } from 'tsafe'; - -import { areTypesEqual } from './areTypesEqual'; -import { getCollectItemType } from './getCollectItemType'; -import { getHasCycles } from './getHasCycles'; /** - * Finds the first valid field for a pending connection between two nodes. - * @param templates The invocation templates + * + * @param source The source (node id) + * @param sourceHandle The source handle (field name), if any + * @param target The target (node id) + * @param targetHandle The target handle (field name), if any * @param nodes The current nodes * @param edges The current edges - * @param pendingConnection The pending connection - * @param candidateNode The candidate node to which the connection is being made - * @param candidateTemplate The candidate template for the candidate node - * @returns The first valid connection, or null if no valid connection is found + * @param templates The current templates + * @param edgePendingUpdate The edge pending update, if any + * @returns */ - export const getFirstValidConnection = ( - templates: Templates, + source: string, + sourceHandle: string | null, + target: string, + targetHandle: string | null, nodes: AnyNode[], edges: InvocationNodeEdge[], - pendingConnection: PendingConnection, - candidateNode: InvocationNode, - candidateTemplate: InvocationTemplate, + templates: Templates, edgePendingUpdate: Edge | null ): Connection | null => { - if (pendingConnection.node.id === candidateNode.id) { - // Cannot connect to self + if (source === target) { return null; } - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; + if (sourceHandle && targetHandle) { + return { source, sourceHandle, target, targetHandle }; + } - if (pendingFieldKind === 'source') { - // Connecting from a source to a target - if (getHasCycles(pendingConnection.node.id, candidateNode.id, nodes, edges)) { - return null; - } - if (candidateNode.data.type === 'collect') { - // Special handling for collect node - the `item` field takes any number of connections - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: 'item', - }; - } - // Only one connection per target field is allowed - look for an unconnected target field - const candidateFields = map(candidateTemplate.inputs); - const candidateConnectedFields = edges - .filter((edge) => edge.target === candidateNode.id || edge.id === edgePendingUpdate?.id) - .map((edge) => { - // Edges must always have a targetHandle, safe to assert here - assert(edge.targetHandle); - return edge.targetHandle; - }); - const candidateUnconnectedFields = differenceWith( - candidateFields, - candidateConnectedFields, - (field, connectedFieldName) => field.name === connectedFieldName + if (sourceHandle && !targetHandle) { + const candidates = getTargetCandidateFields( + source, + sourceHandle, + target, + nodes, + edges, + templates, + edgePendingUpdate ); - const candidateField = candidateUnconnectedFields.find((field) => - validateConnectionTypes(pendingConnection.fieldTemplate.type, field.type) - ); - if (candidateField) { - return { - source: pendingConnection.node.id, - sourceHandle: pendingConnection.fieldTemplate.name, - target: candidateNode.id, - targetHandle: candidateField.name, - }; - } - } else { - // Connecting from a target to a source - // Ensure we there is not already an edge to the target, except for collect nodes - const isCollect = pendingConnection.node.data.type === 'collect'; - const isTargetAlreadyConnected = edges.some( - (e) => e.target === pendingConnection.node.id && e.targetHandle === pendingConnection.fieldTemplate.name - ); - if (!isCollect && isTargetAlreadyConnected) { + + const firstCandidate = candidates[0]; + if (!firstCandidate) { return null; } - if (getHasCycles(candidateNode.id, pendingConnection.node.id, nodes, edges)) { + return { source, sourceHandle, target, targetHandle: firstCandidate.name }; + } + + if (!sourceHandle && targetHandle) { + const candidates = getSourceCandidateFields( + target, + targetHandle, + source, + nodes, + edges, + templates, + edgePendingUpdate + ); + + const firstCandidate = candidates[0]; + if (!firstCandidate) { return null; } - // Sources/outputs can have any number of edges, we can take the first matching output field - let candidateFields = map(candidateTemplate.outputs); - if (isCollect) { - // Narrow candidates to same field type as already is connected to the collect node - const collectItemType = getCollectItemType(templates, nodes, edges, pendingConnection.node.id); - if (collectItemType) { - candidateFields = candidateFields.filter((field) => areTypesEqual(field.type, collectItemType)); - } - } - const candidateField = candidateFields.find((field) => { - const isValid = validateConnectionTypes(field.type, pendingConnection.fieldTemplate.type); - const isAlreadyConnected = edges.some((e) => e.source === candidateNode.id && e.sourceHandle === field.name); - return isValid && !isAlreadyConnected; - }); - if (candidateField) { - return { - source: candidateNode.id, - sourceHandle: candidateField.name, - target: pendingConnection.node.id, - targetHandle: pendingConnection.fieldTemplate.name, - }; - } + return { source, sourceHandle: firstCandidate.name, target, targetHandle }; } return null; From b1e28c2f2c3cb1e638c2c3d76941c92678233f87 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 09:59:59 +1000 Subject: [PATCH 037/207] tests(ui): coverage for getFirstValidConnection --- .../util/getFirstValidConnection.test.ts | 119 +++++++++++++----- 1 file changed, 87 insertions(+), 32 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts index 59d5bbadf6..7d04ea8a58 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -1,5 +1,4 @@ import { deepClone } from 'common/util/deepClone'; -import type { PendingConnection } from 'features/nodes/store/types'; import { getFirstValidConnection, getSourceCandidateFields, @@ -11,60 +10,116 @@ import { describe, expect, it } from 'vitest'; describe('getFirstValidConnection', () => { it('should return null if the pending and candidate nodes are the same node', () => { - const pc: PendingConnection = { node: buildNode(add), template: add, fieldTemplate: add.inputs['a']! }; - const candidateNode = pc.node; - expect(getFirstValidConnection(templates, [pc.node], [], pc, candidateNode, add, null)).toBe(null); + const n = buildNode(add); + expect(getFirstValidConnection(n.id, 'value', n.id, null, [n], [], templates, null)).toBe(null); + }); + + it('should return null if the sourceHandle and targetHandle are null', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + expect(getFirstValidConnection(n1.id, null, n2.id, null, [n1, n2], [], templates, null)).toBe(null); + }); + + it('should return itself if both sourceHandle and targetHandle are provided', () => { + const n1 = buildNode(add); + const n2 = buildNode(add); + expect(getFirstValidConnection(n1.id, 'value', n2.id, 'a', [n1, n2], [], templates, null)).toEqual({ + source: n1.id, + sourceHandle: 'value', + target: n2.id, + targetHandle: 'a', + }); }); describe('connecting from a source to a target', () => { - const pc: PendingConnection = { - node: buildNode(img_resize), - template: img_resize, - fieldTemplate: img_resize.outputs['width']!, - }; - const candidateNode = buildNode(img_resize); + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); it('should return the first valid connection if there are no connected fields', () => { - const r = getFirstValidConnection(templates, [pc.node, candidateNode], [], pc, candidateNode, img_resize, null); + const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [], templates, null); const c = { - source: pc.node.id, - sourceHandle: pc.fieldTemplate.name, - target: candidateNode.id, + source: n1.id, + sourceHandle: 'width', + target: n2.id, targetHandle: 'width', }; expect(r).toEqual(c); }); it('should return the first valid connection if there is a connected field', () => { - const r = getFirstValidConnection( - templates, - [pc.node, candidateNode], - [buildEdge(pc.node.id, 'width', candidateNode.id, 'width')], - pc, - candidateNode, - img_resize, - null - ); + const e = buildEdge(n1.id, 'height', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, null); const c = { - source: pc.node.id, - sourceHandle: pc.fieldTemplate.name, - target: candidateNode.id, + source: n1.id, + sourceHandle: 'width', + target: n2.id, targetHandle: 'height', }; expect(r).toEqual(c); }); it('should return the first valid connection if there is an edgePendingUpdate', () => { - const e = buildEdge(pc.node.id, 'width', candidateNode.id, 'width'); - const r = getFirstValidConnection(templates, [pc.node, candidateNode], [e], pc, candidateNode, img_resize, e); + const e = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, 'width', n2.id, null, [n1, n2], [e], templates, e); const c = { - source: pc.node.id, - sourceHandle: pc.fieldTemplate.name, - target: candidateNode.id, + source: n1.id, + sourceHandle: 'width', + target: n2.id, targetHandle: 'width', }; expect(r).toEqual(c); }); + it('should return null if the target has no valid fields', () => { + const e1 = buildEdge(n1.id, 'width', n2.id, 'width'); + const e2 = buildEdge(n1.id, 'height', n2.id, 'height'); + const n3 = buildNode(add); + const r = getFirstValidConnection(n3.id, 'value', n2.id, null, [n1, n2, n3], [e1, e2], templates, null); + expect(r).toEqual(null); + }); + }); + + describe('connecting from a target to a source', () => { + const n1 = buildNode(img_resize); + const n2 = buildNode(img_resize); + + it('should return the first valid connection if there are no connected fields', () => { + const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is a connected field', () => { + const e = buildEdge(n1.id, 'height', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, null, n2.id, 'height', [n1, n2], [e], templates, null); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'height', + }; + expect(r).toEqual(c); + }); + it('should return the first valid connection if there is an edgePendingUpdate', () => { + const e = buildEdge(n1.id, 'width', n2.id, 'width'); + const r = getFirstValidConnection(n1.id, null, n2.id, 'width', [n1, n2], [e], templates, e); + const c = { + source: n1.id, + sourceHandle: 'width', + target: n2.id, + targetHandle: 'width', + }; + expect(r).toEqual(c); + }); + it('should return null if the target has no valid fields', () => { + const e1 = buildEdge(n1.id, 'width', n2.id, 'width'); + const e2 = buildEdge(n1.id, 'height', n2.id, 'height'); + const n3 = buildNode(add); + const r = getFirstValidConnection(n3.id, null, n2.id, 'a', [n1, n2, n3], [e1, e2], templates, null); + expect(r).toEqual(null); + }); }); - describe('connecting from a target to a source', () => {}); }); describe('getTargetCandidateFields', () => { From 4bda174eb99c5c9ed63690c4811ee187a1035a5b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 11:11:31 +1000 Subject: [PATCH 038/207] tests(ui): coverage for getCollectItemType --- .../store/util/getCollectItemType.test.ts | 33 ++++++++++++++++--- .../nodes/store/util/getCollectItemType.ts | 7 ++-- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts index 2fed41c14c..935250b697 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -1,21 +1,44 @@ +import { deepClone } from 'common/util/deepClone'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; import { add, buildEdge, buildNode, collect, templates } from 'features/nodes/store/util/testUtils'; import type { FieldType } from 'features/nodes/types/field'; +import { unset } from 'lodash-es'; import { describe, expect, it } from 'vitest'; describe(getCollectItemType.name, () => { it('should return the type of the items the collect node collects', () => { const n1 = buildNode(add); const n2 = buildNode(collect); - const nodes = [n1, n2]; - const edges = [buildEdge(n1.id, 'value', n2.id, 'item')]; - const result = getCollectItemType(templates, nodes, edges, n2.id); + const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); + const result = getCollectItemType(templates, [n1, n2], [e1], n2.id); expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); }); it('should return null if the collect node does not have any connections', () => { const n1 = buildNode(collect); - const nodes = [n1]; - const result = getCollectItemType(templates, nodes, [], n1.id); + const result = getCollectItemType(templates, [n1], [], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to collect's node doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType(templates, [n1], [e1], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to collect's node template doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType({ collect }, [n1, n2], [e1], n1.id); + expect(result).toBeNull(); + }); + it("should return null if the first edge to the collect's field template doesn't exist", () => { + const n1 = buildNode(collect); + const n2 = buildNode(add); + const addWithoutOutputValue = deepClone(add); + unset(addWithoutOutputValue, 'outputs.value'); + const e1 = buildEdge(n2.id, 'value', n1.id, 'item'); + const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id); expect(result).toBeNull(); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts index 9e0ce0fbee..e6c117d91e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts @@ -30,6 +30,9 @@ export const getCollectItemType = ( if (!template) { return null; } - const fieldType = template.outputs[firstEdgeToCollect.sourceHandle]?.type ?? null; - return fieldType; + const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle]; + if (!fieldTemplate) { + return null; + } + return fieldTemplate.type; }; From a80e3448f57f90e8e1993d09493498d0179468df Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 11:49:40 +1000 Subject: [PATCH 039/207] feat(ui): rework pendingConnection --- .../flow/AddNodePopover/AddNodePopover.tsx | 37 ++++++++---- .../src/features/nodes/hooks/useConnection.ts | 56 ++++++++++--------- .../nodes/hooks/useConnectionState.ts | 4 +- .../web/src/features/nodes/store/types.ts | 7 ++- .../store/util/makeConnectionErrorSelector.ts | 12 ++-- .../features/nodes/store/util/testUtils.ts | 6 +- .../nodes/store/util/validateConnection.ts | 20 +------ 7 files changed, 73 insertions(+), 69 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 14d69b4720..561890245e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -9,6 +9,7 @@ import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { $cursorPos, + $edgePendingUpdate, $isAddNodePopoverOpen, $pendingConnection, $templates, @@ -28,7 +29,6 @@ import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; -import { assert } from 'tsafe'; const createRegex = memoize( (inputValue: string) => @@ -68,16 +68,18 @@ const AddNodePopover = () => { const filteredTemplates = useMemo(() => { // If we have a connection in progress, we need to filter the node choices + const templatesArray = map(templates); if (!pendingConnection) { - return map(templates); + return templatesArray; } return filter(templates, (template) => { - const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; - const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; - return some(fields, (field) => { - const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; - const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + return some(candidateFields, (field) => { + const sourceType = + pendingConnection.handleType === 'source' ? field.type : pendingConnection.fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? field.type : pendingConnection.fieldTemplate.type; return validateConnectionTypes(sourceType, targetType); }); }); @@ -144,10 +146,25 @@ const AddNodePopover = () => { // Auto-connect an edge if we just added a node and have a pending connection if (pendingConnection && isInvocationNode(node)) { - const template = templates[node.data.type]; - assert(template, 'Template not found'); + const edgePendingUpdate = $edgePendingUpdate.get(); + const { handleType } = pendingConnection; + + const source = handleType === 'source' ? pendingConnection.nodeId : node.id; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : node.id; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const { nodes, edges } = store.getState().nodes.present; - const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, node, template); + const connection = getFirstValidConnection( + source, + sourceHandle, + target, + targetHandle, + nodes, + edges, + templates, + edgePendingUpdate + ); if (connection) { dispatch(connectionMade(connection)); } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index d81a9e5807..f7bf1b8740 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -9,10 +9,10 @@ import { connectionMade, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; -import { isInvocationNode } from 'features/nodes/types/invocation'; import { isString } from 'lodash-es'; import { useCallback, useMemo } from 'react'; -import { type OnConnect, type OnConnectEnd, type OnConnectStart, useUpdateNodeInternals } from 'reactflow'; +import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import { useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; export const useConnection = () => { @@ -21,21 +21,27 @@ export const useConnection = () => { const updateNodeInternals = useUpdateNodeInternals(); const onConnectStart = useCallback( - (event, params) => { + (event, { nodeId, handleId, handleType }) => { + assert(nodeId && handleId && handleType, 'Invalid connection start event'); const nodes = store.getState().nodes.present.nodes; - const { nodeId, handleId, handleType } = params; - assert(nodeId && handleId && handleType, `Invalid connection start params: ${JSON.stringify(params)}`); + const node = nodes.find((n) => n.id === nodeId); - assert(isInvocationNode(node), `Invalid node during connection: ${JSON.stringify(node)}`); + if (!node) { + return; + } + const template = templates[node.data.type]; - assert(template, `Template not found for node type: ${node.data.type}`); - const fieldTemplate = handleType === 'source' ? template.outputs[handleId] : template.inputs[handleId]; - assert(fieldTemplate, `Field template not found for field: ${node.data.type}.${handleId}`); - $pendingConnection.set({ - node, - template, - fieldTemplate, - }); + if (!template) { + return; + } + + const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs']; + const fieldTemplate = fieldTemplates[handleId]; + if (!fieldTemplate) { + return; + } + + $pendingConnection.set({ nodeId, handleId, handleType, fieldTemplate }); }, [store, templates] ); @@ -67,20 +73,20 @@ export const useConnection = () => { } const { nodes, edges } = store.getState().nodes.present; if (mouseOverNodeId) { - const candidateNode = nodes.filter(isInvocationNode).find((n) => n.id === mouseOverNodeId); - if (!candidateNode) { - // The mouse is over a non-invocation node - bail - return; - } - const candidateTemplate = templates[candidateNode.data.type]; - assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`); + const { handleType } = pendingConnection; + const source = handleType === 'source' ? pendingConnection.nodeId : mouseOverNodeId; + const sourceHandle = handleType === 'source' ? pendingConnection.handleId : null; + const target = handleType === 'target' ? pendingConnection.nodeId : mouseOverNodeId; + const targetHandle = handleType === 'target' ? pendingConnection.handleId : null; + const connection = getFirstValidConnection( - templates, + source, + sourceHandle, + target, + targetHandle, nodes, edges, - pendingConnection, - candidateNode, - candidateTemplate, + templates, edgePendingUpdate ); if (connection) { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 7649209863..d218734fff 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -43,8 +43,8 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta return false; } return ( - pendingConnection.node.id === nodeId && - pendingConnection.fieldTemplate.name === fieldName && + pendingConnection.nodeId === nodeId && + pendingConnection.handleId === fieldName && pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 2f514bdb5b..6dcf70cfad 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -6,19 +6,20 @@ import type { } from 'features/nodes/types/field'; import type { AnyNode, - InvocationNode, InvocationNodeEdge, InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; +import type { HandleType } from 'reactflow'; export type Templates = Record; export type NodeExecutionStates = Record; export type PendingConnection = { - node: InvocationNode; - template: InvocationTemplate; + nodeId: string; + handleId: string; + handleType: HandleType; fieldTemplate: FieldInputTemplate | FieldOutputTemplate; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index e1a443a60e..c6d05d2c7c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -36,9 +36,7 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.noConnectionInProgress'); } - const connectionHandleType = pendingConnection.fieldTemplate.fieldKind === 'input' ? 'target' : 'source'; - - if (handleType === connectionHandleType) { + if (handleType === pendingConnection.handleType) { if (handleType === 'source') { return i18n.t('nodes.cannotConnectOutputToOutput'); } @@ -46,10 +44,10 @@ export const makeConnectionErrorSelector = ( } // we have to figure out which is the target and which is the source - const source = handleType === 'source' ? nodeId : pendingConnection.node.id; - const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.fieldTemplate.name; - const target = handleType === 'target' ? nodeId : pendingConnection.node.id; - const targetHandle = handleType === 'target' ? fieldName : pendingConnection.fieldTemplate.name; + const source = handleType === 'source' ? nodeId : pendingConnection.nodeId; + const sourceHandle = handleType === 'source' ? fieldName : pendingConnection.handleId; + const target = handleType === 'target' ? nodeId : pendingConnection.nodeId; + const targetHandle = handleType === 'target' ? fieldName : pendingConnection.handleId; const validationResult = validateConnection( { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index f351083bc5..5155bb14ea 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -2,7 +2,7 @@ import type { Templates } from 'features/nodes/store/types'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import type { OpenAPIV3_1 } from 'openapi-types'; -import type { Edge, XYPosition } from 'reactflow'; +import type { Edge } from 'reactflow'; export const buildEdge = (source: string, sourceHandle: string, target: string, targetHandle: string): Edge => ({ source, @@ -13,8 +13,6 @@ export const buildEdge = (source: string, sourceHandle: string, target: string, id: `reactflow__edge-${source}${sourceHandle}-${target}${targetHandle}`, }); -export const position: XYPosition = { x: 0, y: 0 }; - export const buildNode = (template: InvocationTemplate) => buildInvocationNode({ x: 0, y: 0 }, template); export const add: InvocationTemplate = { @@ -176,7 +174,7 @@ export const collect: InvocationTemplate = { classification: 'stable', }; -export const scheduler: InvocationTemplate = { +const scheduler: InvocationTemplate = { title: 'Scheduler', type: 'scheduler', version: '1.0.0', diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index edb8ac5ecb..56e45c0d80 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -6,11 +6,10 @@ import { validateConnectionTypes } from 'features/nodes/store/util/validateConne import type { AnyNode } from 'features/nodes/types/invocation'; import type { Connection as NullableConnection, Edge } from 'reactflow'; import type { O } from 'ts-toolbelt'; -import { assert } from 'tsafe'; type Connection = O.NonNullable; -export type ValidateConnectionResult = +type ValidateConnectionResult = | { isValid: true; messageTKey?: string; @@ -20,7 +19,7 @@ export type ValidateConnectionResult = messageTKey: string; }; -export type ValidateConnectionFunc = ( +type ValidateConnectionFunc = ( connection: Connection, nodes: AnyNode[], edges: Edge[], @@ -29,21 +28,6 @@ export type ValidateConnectionFunc = ( strict?: boolean ) => ValidateConnectionResult; -export const buildResult = (isValid: boolean, messageTKey?: string): ValidateConnectionResult => { - if (isValid) { - return { - isValid, - messageTKey, - }; - } else { - assert(messageTKey !== undefined); - return { - isValid, - messageTKey, - }; - } -}; - const getEqualityPredicate = (c: Connection) => (e: Edge): boolean => { From 6b11740ddafaa75ef80a3a018a931db6075dfc4d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 11:53:05 +1000 Subject: [PATCH 040/207] chore(ui): knip --- .../web/src/features/nodes/hooks/useFieldType.ts.ts | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts deleted file mode 100644 index 90c08a94aa..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ /dev/null @@ -1,9 +0,0 @@ -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; -import type { FieldType } from 'features/nodes/types/field'; -import { useMemo } from 'react'; - -export const useFieldType = (nodeId: string, fieldName: string, kind: 'inputs' | 'outputs'): FieldType => { - const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); - const fieldType = useMemo(() => fieldTemplate.type, [fieldTemplate]); - return fieldType; -}; From 504ac82077fce730651798d27057a8c003810f58 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 12:15:57 +1000 Subject: [PATCH 041/207] fix(ui): duplicated edges when updating edge with lazy connect --- .../frontend/web/src/features/nodes/hooks/useConnection.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index f7bf1b8740..de01c79b30 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -2,11 +2,13 @@ import { useStore } from '@nanostores/react'; import { useAppStore } from 'app/store/storeHooks'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { + $didUpdateEdge, $edgePendingUpdate, $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade, + edgeDeleted, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { isString } from 'lodash-es'; @@ -93,6 +95,10 @@ export const useConnection = () => { dispatch(connectionMade(connection)); const nodesToUpdate = [connection.source, connection.target].filter(isString); updateNodeInternals(nodesToUpdate); + if (edgePendingUpdate) { + dispatch(edgeDeleted(edgePendingUpdate.id)); + $didUpdateEdge.set(true); + } } $pendingConnection.set(null); } else { From 26029108f7150a1bf870b26b91205d6da2f9d0c0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 13:54:32 +1000 Subject: [PATCH 042/207] feat(ui): rework node and edge mutation logic Remove our DIY'd reducers, consolidating all node and edge mutations to use `edgesChanged` and `nodesChanged`, which are called by reactflow. This makes the API for manipulating nodes and edges less tangly and error-prone. --- .../flow/AddNodePopover/AddNodePopover.tsx | 6 +- .../features/nodes/components/flow/Flow.tsx | 50 ++++------- .../flow/nodes/Invocation/MissingFallback.tsx | 20 +++++ .../Invocation/fields/LinearViewField.tsx | 11 ++- .../sidePanel/viewMode/WorkflowField.tsx | 11 ++- .../sidePanel/workflow/WorkflowLinearTab.tsx | 10 ++- .../src/features/nodes/hooks/useConnection.ts | 25 +++--- .../features/nodes/hooks/useDoesFieldExist.ts | 20 +++++ .../src/features/nodes/store/nodesSlice.ts | 87 ++++++++++--------- .../nodes/store/util/reactFlowUtil.ts | 32 +++++++ .../src/features/nodes/store/workflowSlice.ts | 14 +-- 11 files changed, 186 insertions(+), 100 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 561890245e..12592c86da 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -14,11 +14,12 @@ import { $pendingConnection, $templates, closeAddNodePopover, - connectionMade, + edgesChanged, nodeAdded, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; import type { AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; @@ -166,7 +167,8 @@ const AddNodePopover = () => { edgePendingUpdate ); if (connection) { - dispatch(connectionMade(connection)); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); } } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 18bbac0b44..5327d72478 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -14,29 +14,24 @@ import { $lastEdgeUpdateMouseEvent, $pendingConnection, $viewport, - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, nodesChanged, - nodesDeleted, redo, selectedAll, + selectionDeleted, undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; -import { isString } from 'lodash-es'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import type { OnEdgesChange, - OnEdgesDelete, OnEdgeUpdateFunc, OnInit, OnMoveEnd, OnNodesChange, - OnNodesDelete, ProOptions, ReactFlowProps, ReactFlowState, @@ -50,8 +45,6 @@ import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; -const DELETE_KEYS = ['Delete', 'Backspace']; - const edgeTypes = { collapsed: InvocationCollapsedEdge, default: InvocationDefaultEdge, @@ -109,20 +102,6 @@ export const Flow = memo(() => { [dispatch] ); - const onEdgesDelete: OnEdgesDelete = useCallback( - (edges) => { - dispatch(edgesDeleted(edges)); - }, - [dispatch] - ); - - const onNodesDelete: OnNodesDelete = useCallback( - (nodes) => { - dispatch(nodesDeleted(nodes)); - }, - [dispatch] - ); - const handleMoveEnd: OnMoveEnd = useCallback((e, viewport) => { $viewport.set(viewport); }, []); @@ -167,16 +146,20 @@ export const Flow = memo(() => { }, []); const onEdgeUpdate: OnEdgeUpdateFunc = useCallback( - (edge, newConnection) => { + (oldEdge, newConnection) => { // This event is fired when an edge update is successful $didUpdateEdge.set(true); // When an edge update is successful, we need to delete the old edge and create a new one - dispatch(edgeDeleted(edge.id)); - dispatch(connectionMade(newConnection)); + const newEdge = connectionToEdge(newConnection); + dispatch( + edgesChanged([ + { type: 'remove', id: oldEdge.id }, + { type: 'add', item: newEdge }, + ]) + ); // Because we shift the position of handles depending on whether a field is connected or not, we must use // updateNodeInternals to tell reactflow to recalculate the positions of the handles - const nodesToUpdate = [edge.source, edge.target, newConnection.source, newConnection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + updateNodeInternals([oldEdge.source, oldEdge.target, newEdge.source, newEdge.target]); }, [dispatch, updateNodeInternals] ); @@ -193,7 +176,7 @@ export const Flow = memo(() => { // If we got this far and did not successfully update an edge, and the mouse moved away from the handle, // the user probably intended to delete the edge if (!didUpdateEdge && didMouseMove) { - dispatch(edgeDeleted(edge.id)); + dispatch(edgesChanged([{ type: 'remove', id: edge.id }])); } $edgePendingUpdate.set(null); @@ -267,6 +250,11 @@ export const Flow = memo(() => { }, [cancelConnection]); useHotkeys('esc', onEscapeHotkey); + const onDeleteHotkey = useCallback(() => { + dispatch(selectionDeleted()); + }, [dispatch]); + useHotkeys(['delete', 'backspace'], onDeleteHotkey); + return ( { onMouseMove={onMouseMove} onNodesChange={onNodesChange} onEdgesChange={onEdgesChange} - onEdgesDelete={onEdgesDelete} onEdgeUpdate={onEdgeUpdate} onEdgeUpdateStart={onEdgeUpdateStart} onEdgeUpdateEnd={onEdgeUpdateEnd} - onNodesDelete={onNodesDelete} onConnectStart={onConnectStart} onConnect={onConnect} onConnectEnd={onConnectEnd} @@ -298,7 +284,7 @@ export const Flow = memo(() => { proOptions={proOptions} style={flowStyles} onPaneClick={handlePaneClick} - deleteKeyCode={DELETE_KEYS} + deleteKeyCode={null} selectionMode={selectionMode} elevateEdgesOnSelect > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx new file mode 100644 index 0000000000..ca5b74b7ff --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/MissingFallback.tsx @@ -0,0 +1,20 @@ +import { useDoesFieldExist } from 'features/nodes/hooks/useDoesFieldExist'; +import type { PropsWithChildren } from 'react'; +import { memo } from 'react'; + +type Props = PropsWithChildren<{ + nodeId: string; + fieldName?: string; +}>; + +export const MissingFallback = memo((props: Props) => { + // We must be careful here to avoid race conditions where a deleted node is still referenced as an exposed field + const exists = useDoesFieldExist(props.nodeId, props.fieldName); + if (!exists) { + return null; + } + + return props.children; +}); + +MissingFallback.displayName = 'MissingFallback'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx index 0cd199f7a4..f7ff85f479 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/LinearViewField.tsx @@ -3,6 +3,7 @@ import { CSS } from '@dnd-kit/utilities'; import { Flex, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; +import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { workflowExposedFieldRemoved } from 'features/nodes/store/workflowSlice'; @@ -20,7 +21,7 @@ type Props = { fieldName: string; }; -const LinearViewField = ({ nodeId, fieldName }: Props) => { +const LinearViewFieldInternal = ({ nodeId, fieldName }: Props) => { const dispatch = useAppDispatch(); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); @@ -99,4 +100,12 @@ const LinearViewField = ({ nodeId, fieldName }: Props) => { ); }; +const LinearViewField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx index e707dd4f54..a30bda354d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/viewMode/WorkflowField.tsx @@ -1,6 +1,7 @@ import { Flex, FormLabel, Icon, IconButton, Spacer, Tooltip } from '@invoke-ai/ui-library'; import FieldTooltipContent from 'features/nodes/components/flow/nodes/Invocation/fields/FieldTooltipContent'; import InputFieldRenderer from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer'; +import { MissingFallback } from 'features/nodes/components/flow/nodes/Invocation/MissingFallback'; import { useFieldLabel } from 'features/nodes/hooks/useFieldLabel'; import { useFieldOriginalValue } from 'features/nodes/hooks/useFieldOriginalValue'; import { useFieldTemplateTitle } from 'features/nodes/hooks/useFieldTemplateTitle'; @@ -14,7 +15,7 @@ type Props = { fieldName: string; }; -const WorkflowField = ({ nodeId, fieldName }: Props) => { +const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => { const label = useFieldLabel(nodeId, fieldName); const fieldTemplateTitle = useFieldTemplateTitle(nodeId, fieldName, 'inputs'); const { isValueChanged, onReset } = useFieldOriginalValue(nodeId, fieldName); @@ -50,4 +51,12 @@ const WorkflowField = ({ nodeId, fieldName }: Props) => { ); }; +const WorkflowField = ({ nodeId, fieldName }: Props) => { + return ( + + + + ); +}; + export default memo(WorkflowField); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx index fa1767138e..9b0e5bb9d6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/workflow/WorkflowLinearTab.tsx @@ -6,10 +6,10 @@ import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DndSortable from 'features/dnd/components/DndSortable'; import type { DragEndEvent } from 'features/dnd/types'; -import LinearViewField from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; +import LinearViewFieldInternal from 'features/nodes/components/flow/nodes/Invocation/fields/LinearViewField'; import { selectWorkflowSlice, workflowExposedFieldsReordered } from 'features/nodes/store/workflowSlice'; import type { FieldIdentifier } from 'features/nodes/types/field'; -import { memo, useCallback } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; @@ -40,16 +40,18 @@ const WorkflowLinearTab = () => { [dispatch, fields] ); + const items = useMemo(() => fields.map((field) => `${field.nodeId}.${field.fieldName}`), [fields]); + return ( - `${field.nodeId}.${field.fieldName}`)}> + {isLoading ? ( ) : fields.length ? ( fields.map(({ nodeId, fieldName }) => ( - + )) ) : ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index de01c79b30..36491e80bc 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -7,13 +7,12 @@ import { $isAddNodePopoverOpen, $pendingConnection, $templates, - connectionMade, - edgeDeleted, + edgesChanged, } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; -import { isString } from 'lodash-es'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { useCallback, useMemo } from 'react'; -import type { OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; +import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow'; import { useUpdateNodeInternals } from 'reactflow'; import { assert } from 'tsafe'; @@ -50,9 +49,9 @@ export const useConnection = () => { const onConnect = useCallback( (connection) => { const { dispatch } = store; - dispatch(connectionMade(connection)); - const nodesToUpdate = [connection.source, connection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + const newEdge = connectionToEdge(connection); + dispatch(edgesChanged([{ type: 'add', item: newEdge }])); + updateNodeInternals([newEdge.source, newEdge.target]); $pendingConnection.set(null); }, [store, updateNodeInternals] @@ -92,13 +91,17 @@ export const useConnection = () => { edgePendingUpdate ); if (connection) { - dispatch(connectionMade(connection)); - const nodesToUpdate = [connection.source, connection.target].filter(isString); - updateNodeInternals(nodesToUpdate); + const newEdge = connectionToEdge(connection); + const changes: EdgeChange[] = [{ type: 'add', item: newEdge }]; + + const nodesToUpdate = [newEdge.source, newEdge.target]; if (edgePendingUpdate) { - dispatch(edgeDeleted(edgePendingUpdate.id)); $didUpdateEdge.set(true); + changes.push({ type: 'remove', id: edgePendingUpdate.id }); + nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target); } + dispatch(edgesChanged(changes)); + updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); } else { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts new file mode 100644 index 0000000000..4e97b1689c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesFieldExist.ts @@ -0,0 +1,20 @@ +import { useAppSelector } from 'app/store/storeHooks'; +import { isInvocationNode } from 'features/nodes/types/invocation'; + +export const useDoesFieldExist = (nodeId: string, fieldName?: string) => { + const doesFieldExist = useAppSelector((s) => { + const node = s.nodes.present.nodes.find((n) => n.id === nodeId); + if (!isInvocationNode(node)) { + return false; + } + if (fieldName === undefined) { + return true; + } + if (!node.data.inputs[fieldName]) { + return false; + } + return true; + }); + + return doesFieldExist; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 7915d3608c..a1e32a72fe 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; +import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { @@ -48,8 +49,8 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; import type { MouseEvent } from 'react'; -import type { Connection, Edge, EdgeChange, EdgeRemoveChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; -import { addEdge, applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; +import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; +import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; @@ -124,10 +125,27 @@ export const nodesSlice = createSlice({ state.nodes.push(node); }, edgesChanged: (state, action: PayloadAction) => { - state.edges = applyEdgeChanges(action.payload, state.edges); - }, - connectionMade: (state, action: PayloadAction) => { - state.edges = addEdge({ ...action.payload, type: 'default' }, state.edges); + const changes = deepClone(action.payload); + action.payload.forEach((change) => { + if (change.type === 'remove' || change.type === 'select') { + const edge = state.edges.find((e) => e.id === change.id); + // If we deleted or selected a collapsed edge, we need to find its "hidden" edges and do the same to them + if (edge && edge.type === 'collapsed') { + const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target); + if (change.type === 'remove') { + hiddenEdges.forEach((e) => { + changes.push({ type: 'remove', id: e.id }); + }); + } + if (change.type === 'select') { + hiddenEdges.forEach((e) => { + changes.push({ type: 'select', id: e.id, selected: change.selected }); + }); + } + } + } + }); + state.edges = applyEdgeChanges(changes, state.edges); }, fieldLabelChanged: ( state, @@ -264,33 +282,6 @@ export const nodesSlice = createSlice({ } } }, - edgeDeleted: (state, action: PayloadAction) => { - state.edges = state.edges.filter((e) => e.id !== action.payload); - }, - edgesDeleted: (state, action: PayloadAction) => { - const edges = action.payload; - const collapsedEdges = edges.filter((e) => e.type === 'collapsed'); - - // if we delete a collapsed edge, we need to delete all collapsed edges between the same nodes - if (collapsedEdges.length) { - const edgeChanges: EdgeRemoveChange[] = []; - collapsedEdges.forEach((collapsedEdge) => { - state.edges.forEach((edge) => { - if (edge.source === collapsedEdge.source && edge.target === collapsedEdge.target) { - edgeChanges.push({ id: edge.id, type: 'remove' }); - } - }); - }); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - } - }, - nodesDeleted: (state, action: PayloadAction) => { - action.payload.forEach((node) => { - if (!isInvocationNode(node)) { - return; - } - }); - }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); @@ -435,6 +426,23 @@ export const nodesSlice = createSlice({ state.nodes = applyNodeChanges(nodeChanges, state.nodes); state.edges = applyEdgeChanges(edgeChanges, state.edges); }, + selectionDeleted: (state) => { + const selectedNodes = state.nodes.filter((n) => n.selected); + const selectedEdges = state.edges.filter((e) => e.selected); + + const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({ + id: n.id, + type: 'remove', + })); + + const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({ + id: e.id, + type: 'remove', + })); + + state.nodes = applyNodeChanges(nodeChanges, state.nodes); + state.edges = applyEdgeChanges(edgeChanges, state.edges); + }, undo: (state) => state, redo: (state) => state, }, @@ -457,10 +465,7 @@ export const nodesSlice = createSlice({ }); export const { - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldValueReset, fieldBoardValueChanged, fieldBooleanValueChanged, @@ -488,11 +493,11 @@ export const { nodeLabelChanged, nodeNotesChanged, nodesChanged, - nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, selectedAll, selectionPasted, + selectionDeleted, undo, redo, } = nodesSlice.actions; @@ -580,10 +585,7 @@ export const nodesUndoableConfig: UndoableOptions = { // This is used for tracking `state.workflow.isTouched` export const isAnyNodeOrEdgeMutation = isAnyOf( - connectionMade, - edgeDeleted, edgesChanged, - edgesDeleted, fieldBoardValueChanged, fieldBooleanValueChanged, fieldColorValueChanged, @@ -601,13 +603,14 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldStringValueChanged, fieldVaeModelValueChanged, nodeAdded, + nodesChanged, nodeReplaced, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, nodeNotesChanged, - nodesDeleted, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted + selectionPasted, + selectionDeleted ); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts new file mode 100644 index 0000000000..89be7951a2 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -0,0 +1,32 @@ +import type { Connection, Edge } from 'reactflow'; +import { assert } from 'tsafe'; + +/** + * Gets the edge id for a connection + * Copied from: https://github.com/xyflow/xyflow/blob/v11/packages/core/src/utils/graph.ts#L44-L45 + * Requested for this to be exported in: https://github.com/xyflow/xyflow/issues/4290 + * @param connection The connection to get the id for + * @returns The edge id + */ +const getEdgeId = (connection: Connection): string => { + const { source, sourceHandle, target, targetHandle } = connection; + return `reactflow__edge-${source}${sourceHandle || ''}-${target}${targetHandle || ''}`; +}; + +/** + * Converts a connection to an edge + * @param connection The connection to convert to an edge + * @returns The edge + * @throws If the connection is invalid (e.g. missing source, sourcehandle, target, or targetHandle) + */ +export const connectionToEdge = (connection: Connection): Edge => { + const { source, sourceHandle, target, targetHandle } = connection; + assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); + return { + source, + sourceHandle, + target, + targetHandle, + id: getEdgeId({ source, sourceHandle, target, targetHandle }), + }; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index 6293d3cce5..b3ec4f0614 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; -import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged, nodesDeleted } from 'features/nodes/store/nodesSlice'; +import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice'; import type { FieldIdentifierWithValue, WorkflowMode, @@ -139,16 +139,16 @@ export const workflowSlice = createSlice({ }; }); - builder.addCase(nodesDeleted, (state, action) => { - action.payload.forEach((node) => { - state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== node.id); - }); - }); - builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState)); builder.addCase(nodesChanged, (state, action) => { // Not all changes to nodes should result in the workflow being marked touched + action.payload.forEach((change) => { + if (change.type === 'remove') { + state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id); + } + }); + const filteredChanges = action.payload.filter((change) => { // We always want to mark the workflow as touched if a node is added, removed, or reset if (['add', 'remove', 'reset'].includes(change.type)) { From e4808440429b2b35342e058c7083623688909068 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:06:02 +1000 Subject: [PATCH 043/207] fix(ui): edge styling --- .../flow/edges/InvocationCollapsedEdge.tsx | 30 ++++++++--------- .../flow/edges/InvocationDefaultEdge.tsx | 28 +++++++--------- .../flow/edges/util/getEdgeColor.ts | 14 ++++++++ .../flow/edges/util/makeEdgeSelector.ts | 33 ++++++++++--------- 4 files changed, 56 insertions(+), 49 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx index 2e2fb31154..0d7e7b7d5e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationCollapsedEdge.tsx @@ -2,13 +2,13 @@ import { Badge, Flex } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; +import { makeEdgeSelector } from 'features/nodes/components/flow/edges/util/makeEdgeSelector'; import { $templates } from 'features/nodes/store/nodesSlice'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; -import { makeEdgeSelector } from './util/makeEdgeSelector'; - const InvocationCollapsedEdge = ({ sourceX, sourceY, @@ -18,19 +18,19 @@ const InvocationCollapsedEdge = ({ targetPosition, markerEnd, data, - selected, + selected = false, source, - target, sourceHandleId, + target, targetHandleId, }: EdgeProps<{ count: number }>) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, selected, source, sourceHandleId, target, targetHandleId] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected } = useAppSelector(selector); const [edgePath, labelX, labelY] = getBezierPath({ sourceX, @@ -44,14 +44,8 @@ const InvocationCollapsedEdge = ({ const { base500 } = useChakraThemeTokens(); const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke: base500, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 5 : 'none', - }), - [base500, isSelected, shouldAnimate] + () => getEdgeStyles(base500, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, base500, selected, shouldAnimateEdges] ); return ( @@ -60,11 +54,15 @@ const InvocationCollapsedEdge = ({ {data?.count && data.count > 1 && ( - + {data.count} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx index 2e4340975b..5a27e974e5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/InvocationDefaultEdge.tsx @@ -1,8 +1,8 @@ import { Flex, Text } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppSelector } from 'app/store/storeHooks'; +import { getEdgeStyles } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { $templates } from 'features/nodes/store/nodesSlice'; -import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import type { EdgeProps } from 'reactflow'; import { BaseEdge, EdgeLabelRenderer, getBezierPath } from 'reactflow'; @@ -17,7 +17,7 @@ const InvocationDefaultEdge = ({ sourcePosition, targetPosition, markerEnd, - selected, + selected = false, source, target, sourceHandleId, @@ -25,11 +25,11 @@ const InvocationDefaultEdge = ({ }: EdgeProps) => { const templates = useStore($templates); const selector = useMemo( - () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId, selected), - [templates, source, sourceHandleId, target, targetHandleId, selected] + () => makeEdgeSelector(templates, source, sourceHandleId, target, targetHandleId), + [templates, source, sourceHandleId, target, targetHandleId] ); - const { isSelected, shouldAnimate, stroke, label } = useAppSelector(selector); + const { shouldAnimateEdges, areConnectedNodesSelected, stroke, label } = useAppSelector(selector); const shouldShowEdgeLabels = useAppSelector((s) => s.workflowSettings.shouldShowEdgeLabels); const [edgePath, labelX, labelY] = getBezierPath({ @@ -41,15 +41,9 @@ const InvocationDefaultEdge = ({ targetPosition, }); - const edgeStyles = useMemo( - () => ({ - strokeWidth: isSelected ? 3 : 2, - stroke, - opacity: isSelected ? 0.8 : 0.5, - animation: shouldAnimate ? 'dashdraw 0.5s linear infinite' : undefined, - strokeDasharray: shouldAnimate ? 5 : 'none', - }), - [isSelected, shouldAnimate, stroke] + const edgeStyles = useMemo( + () => getEdgeStyles(stroke, selected, shouldAnimateEdges, areConnectedNodesSelected), + [areConnectedNodesSelected, stroke, selected, shouldAnimateEdges] ); return ( @@ -65,13 +59,13 @@ const InvocationDefaultEdge = ({ bg="base.800" borderRadius="base" borderWidth={1} - borderColor={isSelected ? 'undefined' : 'transparent'} - opacity={isSelected ? 1 : 0.5} + borderColor={selected ? 'undefined' : 'transparent'} + opacity={selected ? 1 : 0.5} py={1} px={3} shadow="md" > - + {label} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts index e7fa43015b..91c834011c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -1,6 +1,7 @@ import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { FIELD_COLORS } from 'features/nodes/types/constants'; import type { FieldType } from 'features/nodes/types/field'; +import type { CSSProperties } from 'react'; export const getFieldColor = (fieldType: FieldType | null): string => { if (!fieldType) { @@ -10,3 +11,16 @@ export const getFieldColor = (fieldType: FieldType | null): string => { return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); }; + +export const getEdgeStyles = ( + stroke: string, + selected: boolean, + shouldAnimateEdges: boolean, + areConnectedNodesSelected: boolean +): CSSProperties => ({ + strokeWidth: selected ? 3 : areConnectedNodesSelected ? 2 : 1, + stroke, + opacity: selected ? 1 : 0.5, + animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined, + strokeDasharray: selected || areConnectedNodesSelected ? 5 : 'none', +}); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index 87ef8eb629..9c67728722 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -1,5 +1,6 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { deepClone } from 'common/util/deepClone'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { Templates } from 'features/nodes/store/types'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; @@ -8,8 +9,8 @@ import { isInvocationNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; const defaultReturnValue = { - isSelected: false, - shouldAnimate: false, + areConnectedNodesSelected: false, + shouldAnimateEdges: false, stroke: colorTokenToCssVar('base.500'), label: '', }; @@ -19,21 +20,27 @@ export const makeEdgeSelector = ( source: string, sourceHandleId: string | null | undefined, target: string, - targetHandleId: string | null | undefined, - selected?: boolean + targetHandleId: string | null | undefined ) => createMemoizedSelector( selectNodesSlice, selectWorkflowSettingsSlice, - (nodes, workflowSettings): { isSelected: boolean; shouldAnimate: boolean; stroke: string; label: string } => { + ( + nodes, + workflowSettings + ): { areConnectedNodesSelected: boolean; shouldAnimateEdges: boolean; stroke: string; label: string } => { + const { shouldAnimateEdges, shouldColorEdges } = workflowSettings; const sourceNode = nodes.nodes.find((node) => node.id === source); const targetNode = nodes.nodes.find((node) => node.id === target); + const returnValue = deepClone(defaultReturnValue); + returnValue.shouldAnimateEdges = shouldAnimateEdges; + const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const isSelected = Boolean(sourceNode?.selected || targetNode?.selected || selected); + returnValue.areConnectedNodesSelected = Boolean(sourceNode?.selected || targetNode?.selected); if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { - return defaultReturnValue; + return returnValue; } const sourceNodeTemplate = templates[sourceNode.data.type]; @@ -42,16 +49,10 @@ export const makeEdgeSelector = ( const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; - const stroke = - sourceType && workflowSettings.shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); + returnValue.stroke = sourceType && shouldColorEdges ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); - const label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + returnValue.label = `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; - return { - isSelected, - shouldAnimate: workflowSettings.shouldAnimateEdges && isSelected, - stroke, - label, - }; + return returnValue; } ); From b3429553bb7280556ff111160c85f585c94c2329 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:25:54 +1000 Subject: [PATCH 044/207] fix(ui): collapsed edges selected state --- invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index a1e32a72fe..05b53c518d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -250,6 +250,7 @@ export const nodesSlice = createSlice({ type: 'collapsed', data: { count: 1 }, updatable: false, + selected: edge.selected, }); } } @@ -270,6 +271,7 @@ export const nodesSlice = createSlice({ type: 'collapsed', data: { count: 1 }, updatable: false, + selected: edge.selected, }); } } From 21fab9785ac3ae0a9723b587762e977d102a9a96 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:28:39 +1000 Subject: [PATCH 045/207] feat(ui): tweak edge styling --- .../features/nodes/components/flow/edges/util/getEdgeColor.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts index 91c834011c..b5801c45ed 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -18,7 +18,7 @@ export const getEdgeStyles = ( shouldAnimateEdges: boolean, areConnectedNodesSelected: boolean ): CSSProperties => ({ - strokeWidth: selected ? 3 : areConnectedNodesSelected ? 2 : 1, + strokeWidth: 3, stroke, opacity: selected ? 1 : 0.5, animation: shouldAnimateEdges ? 'dashdraw 0.5s linear infinite' : undefined, From e38d75c3dc7df3ecad9c897cabbc31d6644312a0 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:36:49 +1000 Subject: [PATCH 046/207] feat(ui): get rid of nodeAdded --- .../flow/AddNodePopover/AddNodePopover.tsx | 26 ++++++++++++++++--- .../src/features/nodes/store/nodesSlice.ts | 25 ------------------ 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 12592c86da..6e695561a2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -15,9 +15,10 @@ import { $templates, closeAddNodePopover, edgesChanged, - nodeAdded, + nodesChanged, openAddNodePopover, } from 'features/nodes/store/nodesSlice'; +import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; @@ -30,6 +31,7 @@ import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; +import type { EdgeChange, NodeChange } from 'reactflow'; const createRegex = memoize( (inputValue: string) => @@ -131,11 +133,29 @@ const AddNodePopover = () => { }); return null; } + + // Find a cozy spot for the node const cursorPos = $cursorPos.get(); - dispatch(nodeAdded({ node, cursorPos })); + const { nodes, edges } = store.getState().nodes.present; + node.position = findUnoccupiedPosition(nodes, cursorPos?.x ?? node.position.x, cursorPos?.y ?? node.position.y); + node.selected = true; + + // Deselect all other nodes and edges + const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; + const edgeChanges: EdgeChange[] = []; + nodes.forEach((n) => { + nodeChanges.push({ id: n.id, type: 'select', selected: false }); + }); + edges.forEach((e) => { + edgeChanges.push({ id: e.id, type: 'select', selected: false }); + }); + + // Onwards! + dispatch(nodesChanged(nodeChanges)); + dispatch(edgesChanged(edgeChanges)); return node; }, - [dispatch, buildInvocation, toaster, t] + [buildInvocation, store, dispatch, t, toaster] ); const onChange = useCallback( diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 05b53c518d..5f0dbb2b14 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -55,7 +55,6 @@ import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; import type { NodesState, PendingConnection, Templates } from './types'; -import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; const initialNodesState: NodesState = { _version: 1, @@ -102,28 +101,6 @@ export const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: (state, action: PayloadAction<{ node: AnyNode; cursorPos: XYPosition | null }>) => { - const { node, cursorPos } = action.payload; - const position = findUnoccupiedPosition( - state.nodes, - cursorPos?.x ?? node.position.x, - cursorPos?.y ?? node.position.y - ); - node.position = position; - node.selected = true; - - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ id: n.id, type: 'select', selected: false })), - state.nodes - ); - - state.edges = applyEdgeChanges( - state.edges.map((e) => ({ id: e.id, type: 'select', selected: false })), - state.edges - ); - - state.nodes.push(node); - }, edgesChanged: (state, action: PayloadAction) => { const changes = deepClone(action.payload); action.payload.forEach((change) => { @@ -486,7 +463,6 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, nodeReplaced, nodeEditorReset, nodeExclusivelySelected, @@ -604,7 +580,6 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeAdded, nodesChanged, nodeReplaced, nodeIsIntermediateChanged, From 1d7671298f768dad7ebbb037fd3a0e3b7cde6132 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:43:44 +1000 Subject: [PATCH 047/207] fix(ui): group edge selection actions --- invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 5f0dbb2b14..28a5e2edb2 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -528,6 +528,11 @@ const isSelectionAction = (action: UnknownAction) => { return true; } } + if (edgesChanged.match(action)) { + if (action.payload.every((change) => change.type === 'select')) { + return true; + } + } return false; }; From 9a8e0842bbe5da5c820b17e36789494fe1c8ae42 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:44:27 +1000 Subject: [PATCH 048/207] feat(ui): remove nodeReplaced action --- .../listeners/updateAllNodesRequested.ts | 9 +++++++-- .../web/src/features/nodes/store/nodesSlice.ts | 11 +---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 63d960b406..05cc2f8e83 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,7 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice'; +import { $templates, nodesChanged } from 'features/nodes/store/nodesSlice'; import { NodeUpdateError } from 'features/nodes/types/error'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate'; @@ -31,7 +31,12 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi } try { const updatedNode = updateNode(node, template); - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + dispatch( + nodesChanged([ + { type: 'remove', id: updatedNode.id }, + { type: 'add', item: updatedNode }, + ]) + ); } catch (e) { if (e instanceof NodeUpdateError) { unableToUpdateCount++; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 28a5e2edb2..4ef03ee658 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -49,7 +49,7 @@ import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocatio import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import { atom } from 'nanostores'; import type { MouseEvent } from 'react'; -import type { Edge, EdgeChange, Node, NodeChange, Viewport, XYPosition } from 'reactflow'; +import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow'; import { applyEdgeChanges, applyNodeChanges, getConnectedEdges, getIncomers, getOutgoers } from 'reactflow'; import type { UndoableOptions } from 'redux-undo'; import type { z } from 'zod'; @@ -94,13 +94,6 @@ export const nodesSlice = createSlice({ nodesChanged: (state, action: PayloadAction) => { state.nodes = applyNodeChanges(action.payload, state.nodes); }, - nodeReplaced: (state, action: PayloadAction<{ nodeId: string; node: Node }>) => { - const nodeIndex = state.nodes.findIndex((n) => n.id === action.payload.nodeId); - if (nodeIndex < 0) { - return; - } - state.nodes[nodeIndex] = action.payload.node; - }, edgesChanged: (state, action: PayloadAction) => { const changes = deepClone(action.payload); action.payload.forEach((change) => { @@ -463,7 +456,6 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - nodeReplaced, nodeEditorReset, nodeExclusivelySelected, nodeIsIntermediateChanged, @@ -586,7 +578,6 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( fieldStringValueChanged, fieldVaeModelValueChanged, nodesChanged, - nodeReplaced, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, From cbe32b647a3d9829d0bc1b8054209b585f2e382b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:47:24 +1000 Subject: [PATCH 049/207] feat(ui): remove selectedAll action --- .../features/nodes/components/flow/Flow.tsx | 20 +++++++++++++++---- .../src/features/nodes/store/nodesSlice.ts | 13 +----------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 5327d72478..df233f4a18 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,6 +1,6 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { useConnection } from 'features/nodes/hooks/useConnection'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState'; @@ -17,7 +17,6 @@ import { edgesChanged, nodesChanged, redo, - selectedAll, selectionDeleted, undo, } from 'features/nodes/store/nodesSlice'; @@ -27,6 +26,8 @@ import type { CSSProperties, MouseEvent } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import type { + EdgeChange, + NodeChange, OnEdgesChange, OnEdgeUpdateFunc, OnInit, @@ -77,6 +78,7 @@ export const Flow = memo(() => { const isValidConnection = useIsValidConnection(); const cancelConnection = useReactFlowStore(selectCancelConnection); const updateNodeInternals = useUpdateNodeInternals(); + const store = useAppStore(); useWorkflowWatcher(); useSyncExecutionState(); const [borderRadius] = useToken('radii', ['base']); @@ -203,9 +205,19 @@ export const Flow = memo(() => { const onSelectAllHotkey = useCallback( (e: KeyboardEvent) => { e.preventDefault(); - dispatch(selectedAll()); + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes.forEach((n) => { + nodeChanges.push({ id: n.id, type: 'select', selected: true }); + }); + edges.forEach((e) => { + edgeChanges.push({ id: e.id, type: 'select', selected: true }); + }); + dispatch(nodesChanged(nodeChanges)); + dispatch(edgesChanged(edgeChanges)); }, - [dispatch] + [dispatch, store] ); useHotkeys(['Ctrl+a', 'Meta+a'], onSelectAllHotkey); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 4ef03ee658..0838778454 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -347,16 +347,6 @@ export const nodesSlice = createSlice({ state.nodes = []; state.edges = []; }, - selectedAll: (state) => { - state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ id: n.id, type: 'select', selected: true })), - state.nodes - ); - state.edges = applyEdgeChanges( - state.edges.map((e) => ({ id: e.id, type: 'select', selected: true })), - state.edges - ); - }, selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => { const { nodes, edges } = action.payload; @@ -465,7 +455,6 @@ export const { nodesChanged, nodeUseCacheChanged, notesNodeValueChanged, - selectedAll, selectionPasted, selectionDeleted, undo, @@ -509,7 +498,7 @@ export const nodesPersistConfig: PersistConfig = { persistDenylist: [], }; -const selectionMatcher = isAnyOf(selectedAll, selectionPasted, nodeExclusivelySelected); +const selectionMatcher = isAnyOf(selectionPasted, nodeExclusivelySelected); const isSelectionAction = (action: UnknownAction) => { if (selectionMatcher(action)) { From 7cceafe0dd49eeb67ad8373faf5b16cf5ac7ef5d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:52:11 +1000 Subject: [PATCH 050/207] feat(ui): remove selectionPasted action --- .../src/features/nodes/hooks/useCopyPaste.ts | 43 ++++++++++++++++-- .../src/features/nodes/store/nodesSlice.ts | 45 +------------------ 2 files changed, 40 insertions(+), 48 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts index 08def1514c..8be972363f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts @@ -5,11 +5,13 @@ import { $copiedNodes, $cursorPos, $edgesToCopiedNodes, - selectionPasted, + edgesChanged, + nodesChanged, selectNodesSlice, } from 'features/nodes/store/nodesSlice'; import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; import { isEqual, uniqWith } from 'lodash-es'; +import type { EdgeChange, NodeChange } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; const copySelection = () => { @@ -26,7 +28,7 @@ const copySelection = () => { const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const { getState, dispatch } = getStore(); - const currentNodes = selectNodesSlice(getState()).nodes; + const { nodes, edges } = selectNodesSlice(getState()); const cursorPos = $cursorPos.get(); const copiedNodes = deepClone($copiedNodes.get()); @@ -46,7 +48,7 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const offsetY = cursorPos ? cursorPos.y - minY : 50; copiedNodes.forEach((node) => { - const { x, y } = findUnoccupiedPosition(currentNodes, node.position.x + offsetX, node.position.y + offsetY); + const { x, y } = findUnoccupiedPosition(nodes, node.position.x + offsetX, node.position.y + offsetY); node.position.x = x; node.position.y = y; // Pasted nodes are selected @@ -68,7 +70,40 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { node.data.id = id; }); - dispatch(selectionPasted({ nodes: copiedNodes, edges: copiedEdges })); + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + // Deselect existing nodes + nodes.forEach((n) => { + nodeChanges.push({ + id: n.data.id, + type: 'select', + selected: false, + }); + }); + // Add new nodes + copiedNodes.forEach((n) => { + nodeChanges.push({ + item: n, + type: 'add', + }); + }); + // Deselect existing edges + edges.forEach((e) => { + edgeChanges.push({ + id: e.id, + type: 'select', + selected: false, + }); + }); + // Add new edges + copiedEdges.forEach((e) => { + edgeChanges.push({ + item: e, + type: 'add', + }); + }); + dispatch(nodesChanged(nodeChanges)); + dispatch(edgesChanged(edgeChanges)); }; const api = { copySelection, pasteSelection }; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 0838778454..416d7065bb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -347,47 +347,6 @@ export const nodesSlice = createSlice({ state.nodes = []; state.edges = []; }, - selectionPasted: (state, action: PayloadAction<{ nodes: AnyNode[]; edges: InvocationNodeEdge[] }>) => { - const { nodes, edges } = action.payload; - - const nodeChanges: NodeChange[] = []; - - // Deselect existing nodes - state.nodes.forEach((n) => { - nodeChanges.push({ - id: n.data.id, - type: 'select', - selected: false, - }); - }); - // Add new nodes - nodes.forEach((n) => { - nodeChanges.push({ - item: n, - type: 'add', - }); - }); - - const edgeChanges: EdgeChange[] = []; - // Deselect existing edges - state.edges.forEach((e) => { - edgeChanges.push({ - id: e.id, - type: 'select', - selected: false, - }); - }); - // Add new edges - edges.forEach((e) => { - edgeChanges.push({ - item: e, - type: 'add', - }); - }); - - state.nodes = applyNodeChanges(nodeChanges, state.nodes); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - }, selectionDeleted: (state) => { const selectedNodes = state.nodes.filter((n) => n.selected); const selectedEdges = state.edges.filter((e) => e.selected); @@ -455,7 +414,6 @@ export const { nodesChanged, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted, selectionDeleted, undo, redo, @@ -498,7 +456,7 @@ export const nodesPersistConfig: PersistConfig = { persistDenylist: [], }; -const selectionMatcher = isAnyOf(selectionPasted, nodeExclusivelySelected); +const selectionMatcher = isAnyOf(nodeExclusivelySelected); const isSelectionAction = (action: UnknownAction) => { if (selectionMatcher(action)) { @@ -573,6 +531,5 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( nodeNotesChanged, nodeUseCacheChanged, notesNodeValueChanged, - selectionPasted, selectionDeleted ); From b8b671c0db0f18935d5370d35ccbe830c4615348 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:54:35 +1000 Subject: [PATCH 051/207] feat(ui): remove selectionDeleted action --- .../features/nodes/components/flow/Flow.tsx | 19 ++++++++++++++--- .../src/features/nodes/store/nodesSlice.ts | 21 +------------------ 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index df233f4a18..75983b1617 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -17,7 +17,6 @@ import { edgesChanged, nodesChanged, redo, - selectionDeleted, undo, } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; @@ -263,8 +262,22 @@ export const Flow = memo(() => { useHotkeys('esc', onEscapeHotkey); const onDeleteHotkey = useCallback(() => { - dispatch(selectionDeleted()); - }, [dispatch]); + const { nodes, edges } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = []; + const edgeChanges: EdgeChange[] = []; + nodes + .filter((n) => n.selected) + .forEach(({ id }) => { + nodeChanges.push({ type: 'remove', id }); + }); + edges + .filter((e) => e.selected) + .forEach(({ id }) => { + edgeChanges.push({ type: 'remove', id }); + }); + dispatch(nodesChanged(nodeChanges)); + dispatch(edgesChanged(edgeChanges)); + }, [dispatch, store]); useHotkeys(['delete', 'backspace'], onDeleteHotkey); return ( diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 416d7065bb..70ac801009 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -347,23 +347,6 @@ export const nodesSlice = createSlice({ state.nodes = []; state.edges = []; }, - selectionDeleted: (state) => { - const selectedNodes = state.nodes.filter((n) => n.selected); - const selectedEdges = state.edges.filter((e) => e.selected); - - const nodeChanges: NodeChange[] = selectedNodes.map((n) => ({ - id: n.id, - type: 'remove', - })); - - const edgeChanges: EdgeChange[] = selectedEdges.map((e) => ({ - id: e.id, - type: 'remove', - })); - - state.nodes = applyNodeChanges(nodeChanges, state.nodes); - state.edges = applyEdgeChanges(edgeChanges, state.edges); - }, undo: (state) => state, redo: (state) => state, }, @@ -414,7 +397,6 @@ export const { nodesChanged, nodeUseCacheChanged, notesNodeValueChanged, - selectionDeleted, undo, redo, } = nodesSlice.actions; @@ -530,6 +512,5 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( nodeLabelChanged, nodeNotesChanged, nodeUseCacheChanged, - notesNodeValueChanged, - selectionDeleted + notesNodeValueChanged ); From a51142674a3b1a435aceca9a6e435dce92a76cf5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 14:57:54 +1000 Subject: [PATCH 052/207] tidy(ui): more succinct syntax for edge and node updates --- .../flow/AddNodePopover/AddNodePopover.tsx | 8 ++++---- .../features/nodes/components/flow/Flow.tsx | 8 ++++---- .../src/features/nodes/hooks/useCopyPaste.ts | 12 ++++++------ .../web/src/features/nodes/store/nodesSlice.ts | 18 +++++++++--------- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 6e695561a2..357514f380 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -143,11 +143,11 @@ const AddNodePopover = () => { // Deselect all other nodes and edges const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; const edgeChanges: EdgeChange[] = []; - nodes.forEach((n) => { - nodeChanges.push({ id: n.id, type: 'select', selected: false }); + nodes.forEach(({ id }) => { + nodeChanges.push({ type: 'select', id, selected: false }); }); - edges.forEach((e) => { - edgeChanges.push({ id: e.id, type: 'select', selected: false }); + edges.forEach(({ id }) => { + edgeChanges.push({ type: 'select', id, selected: false }); }); // Onwards! diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 75983b1617..8e67758d62 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -207,11 +207,11 @@ export const Flow = memo(() => { const { nodes, edges } = store.getState().nodes.present; const nodeChanges: NodeChange[] = []; const edgeChanges: EdgeChange[] = []; - nodes.forEach((n) => { - nodeChanges.push({ id: n.id, type: 'select', selected: true }); + nodes.forEach(({ id }) => { + nodeChanges.push({ type: 'select', id, selected: true }); }); - edges.forEach((e) => { - edgeChanges.push({ id: e.id, type: 'select', selected: true }); + edges.forEach(({ id }) => { + edgeChanges.push({ type: 'select', id, selected: true }); }); dispatch(nodesChanged(nodeChanges)); dispatch(edgesChanged(edgeChanges)); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts index 8be972363f..4ca331d61b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts @@ -73,33 +73,33 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const nodeChanges: NodeChange[] = []; const edgeChanges: EdgeChange[] = []; // Deselect existing nodes - nodes.forEach((n) => { + nodes.forEach(({ id }) => { nodeChanges.push({ - id: n.data.id, type: 'select', + id, selected: false, }); }); // Add new nodes copiedNodes.forEach((n) => { nodeChanges.push({ - item: n, type: 'add', + item: n, }); }); // Deselect existing edges - edges.forEach((e) => { + edges.forEach(({ id }) => { edgeChanges.push({ - id: e.id, type: 'select', + id, selected: false, }); }); // Add new edges copiedEdges.forEach((e) => { edgeChanges.push({ - item: e, type: 'add', + item: e, }); }); dispatch(nodesChanged(nodeChanges)); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 70ac801009..3f8e76825c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -103,13 +103,13 @@ export const nodesSlice = createSlice({ if (edge && edge.type === 'collapsed') { const hiddenEdges = state.edges.filter((e) => e.source === edge.source && e.target === edge.target); if (change.type === 'remove') { - hiddenEdges.forEach((e) => { - changes.push({ type: 'remove', id: e.id }); + hiddenEdges.forEach(({ id }) => { + changes.push({ type: 'remove', id }); }); } if (change.type === 'select') { - hiddenEdges.forEach((e) => { - changes.push({ type: 'select', id: e.id, selected: change.selected }); + hiddenEdges.forEach(({ id }) => { + changes.push({ type: 'select', id, selected: change.selected }); }); } } @@ -275,10 +275,10 @@ export const nodesSlice = createSlice({ nodeExclusivelySelected: (state, action: PayloadAction) => { const nodeId = action.payload; state.nodes = applyNodeChanges( - state.nodes.map((n) => ({ - id: n.id, + state.nodes.map(({ id }) => ({ type: 'select', - selected: n.id === nodeId ? true : false, + id, + selected: id === nodeId ? true : false, })), state.nodes ); @@ -355,13 +355,13 @@ export const nodesSlice = createSlice({ const { nodes, edges } = action.payload; state.nodes = applyNodeChanges( nodes.map((node) => ({ - item: { ...node, ...SHARED_NODE_PROPERTIES }, type: 'add', + item: { ...node, ...SHARED_NODE_PROPERTIES }, })), [] ); state.edges = applyEdgeChanges( - edges.map((edge) => ({ item: edge, type: 'add' })), + edges.map((edge) => ({ type: 'add', item: edge })), [] ); }); From 0b5696c5d4ce5df220f28cdc258e4febe5317d55 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 15:01:30 +1000 Subject: [PATCH 053/207] feat(ui): remove nodeExclusivelySelected action --- .../flow/nodes/common/NodeWrapper.tsx | 12 ++++++++---- .../web/src/features/nodes/store/nodesSlice.ts | 17 ----------------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index 57426982ef..a0260c7301 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -1,14 +1,15 @@ import type { ChakraProps } from '@invoke-ai/ui-library'; import { Box, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import NodeSelectionOverlay from 'common/components/NodeSelectionOverlay'; import { useExecutionState } from 'features/nodes/hooks/useExecutionState'; import { useMouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; -import { nodeExclusivelySelected } from 'features/nodes/store/nodesSlice'; +import { nodesChanged } from 'features/nodes/store/nodesSlice'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from 'features/nodes/types/constants'; import { zNodeStatus } from 'features/nodes/types/invocation'; import type { MouseEvent, PropsWithChildren } from 'react'; import { memo, useCallback } from 'react'; +import type { NodeChange } from 'reactflow'; type NodeWrapperProps = PropsWithChildren & { nodeId: string; @@ -18,6 +19,7 @@ type NodeWrapperProps = PropsWithChildren & { const NodeWrapper = (props: NodeWrapperProps) => { const { nodeId, width, children, selected } = props; + const store = useAppStore(); const { isMouseOverNode, handleMouseOut, handleMouseOver } = useMouseOverNode(nodeId); const executionState = useExecutionState(nodeId); @@ -37,11 +39,13 @@ const NodeWrapper = (props: NodeWrapperProps) => { const handleClick = useCallback( (e: MouseEvent) => { if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) { - dispatch(nodeExclusivelySelected(nodeId)); + const { nodes } = store.getState().nodes.present; + const nodeChanges: NodeChange[] = nodes.map(({ id }) => ({ type: 'select', id, selected: id === nodeId })); + dispatch(nodesChanged(nodeChanges)); } onCloseGlobal(); }, - [dispatch, onCloseGlobal, nodeId] + [onCloseGlobal, store, dispatch, nodeId] ); return ( diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 3f8e76825c..9cc641769c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -272,17 +272,6 @@ export const nodesSlice = createSlice({ } node.data.notes = notes; }, - nodeExclusivelySelected: (state, action: PayloadAction) => { - const nodeId = action.payload; - state.nodes = applyNodeChanges( - state.nodes.map(({ id }) => ({ - type: 'select', - id, - selected: id === nodeId ? true : false, - })), - state.nodes - ); - }, fieldValueReset: (state, action: FieldValueAction) => { fieldValueReducer(state, action, zStatefulFieldValue); }, @@ -389,7 +378,6 @@ export const { fieldStringValueChanged, fieldVaeModelValueChanged, nodeEditorReset, - nodeExclusivelySelected, nodeIsIntermediateChanged, nodeIsOpenChanged, nodeLabelChanged, @@ -438,12 +426,7 @@ export const nodesPersistConfig: PersistConfig = { persistDenylist: [], }; -const selectionMatcher = isAnyOf(nodeExclusivelySelected); - const isSelectionAction = (action: UnknownAction) => { - if (selectionMatcher(action)) { - return true; - } if (nodesChanged.match(action)) { if (action.payload.every((change) => change.type === 'select')) { return true; From 9ed5698aa8e7bbb3afc312c891a51f06ff4bbc64 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 15:11:51 +1000 Subject: [PATCH 054/207] fix(ui): do not remove exposed fields when updating workflows --- .../src/features/nodes/store/workflowSlice.ts | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts index b3ec4f0614..0d358f56e4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSlice.ts @@ -142,13 +142,29 @@ export const workflowSlice = createSlice({ builder.addCase(nodeEditorReset, () => deepClone(initialWorkflowState)); builder.addCase(nodesChanged, (state, action) => { - // Not all changes to nodes should result in the workflow being marked touched - action.payload.forEach((change) => { - if (change.type === 'remove') { - state.exposedFields = state.exposedFields.filter((f) => f.nodeId !== change.id); + // If a node was removed, we should remove any exposed fields that were associated with it. However, node changes + // may remove and then add the same node back. For example, when updating a workflow, we replace old nodes with + // updated nodes. In this case, we should not remove the exposed fields. To handle this, we find the last remove + // and add changes for each exposed field. If the remove change comes after the add change, we remove the exposed + // field. + const exposedFieldsToRemove: FieldIdentifier[] = []; + state.exposedFields.forEach((field) => { + const removeIndex = action.payload.findLastIndex( + (change) => change.type === 'remove' && change.id === field.nodeId + ); + const addIndex = action.payload.findLastIndex( + (change) => change.type === 'add' && change.item.id === field.nodeId + ); + if (removeIndex > addIndex) { + exposedFieldsToRemove.push({ nodeId: field.nodeId, fieldName: field.fieldName }); } }); + state.exposedFields = state.exposedFields.filter( + (field) => !exposedFieldsToRemove.some((f) => isEqual(f, field)) + ); + + // Not all changes to nodes should result in the workflow being marked touched const filteredChanges = action.payload.filter((change) => { // We always want to mark the workflow as touched if a node is added, removed, or reset if (['add', 'remove', 'reset'].includes(change.type)) { @@ -165,7 +181,7 @@ export const workflowSlice = createSlice({ return false; }); - if (filteredChanges.length > 0) { + if (filteredChanges.length > 0 || exposedFieldsToRemove.length > 0) { state.isTouched = true; } }); From 059c5586a48628fb37ce654392e656087f6b8a6d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 15:43:14 +1000 Subject: [PATCH 055/207] perf(ui): ignore all no-op node and edge changes --- .../flow/AddNodePopover/AddNodePopover.tsx | 20 +++++++---- .../features/nodes/components/flow/Flow.tsx | 36 +++++++++++++------ .../flow/nodes/common/NodeWrapper.tsx | 11 ++++-- .../src/features/nodes/hooks/useConnection.ts | 6 ++-- .../src/features/nodes/hooks/useCopyPaste.ts | 36 +++++++++++-------- 5 files changed, 73 insertions(+), 36 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 357514f380..226a8f006d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -143,16 +143,24 @@ const AddNodePopover = () => { // Deselect all other nodes and edges const nodeChanges: NodeChange[] = [{ type: 'add', item: node }]; const edgeChanges: EdgeChange[] = []; - nodes.forEach(({ id }) => { - nodeChanges.push({ type: 'select', id, selected: false }); + nodes.forEach(({ id, selected }) => { + if (selected) { + nodeChanges.push({ type: 'select', id, selected: false }); + } }); - edges.forEach(({ id }) => { - edgeChanges.push({ type: 'select', id, selected: false }); + edges.forEach(({ id, selected }) => { + if (selected) { + edgeChanges.push({ type: 'select', id, selected: false }); + } }); // Onwards! - dispatch(nodesChanged(nodeChanges)); - dispatch(edgesChanged(edgeChanges)); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } return node; }, [buildInvocation, store, dispatch, t, toaster] diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 8e67758d62..19f56b7747 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -90,15 +90,17 @@ export const Flow = memo(() => { ); const onNodesChange: OnNodesChange = useCallback( - (changes) => { - dispatch(nodesChanged(changes)); + (nodeChanges) => { + dispatch(nodesChanged(nodeChanges)); }, [dispatch] ); const onEdgesChange: OnEdgesChange = useCallback( (changes) => { - dispatch(edgesChanged(changes)); + if (changes.length > 0) { + dispatch(edgesChanged(changes)); + } }, [dispatch] ); @@ -207,14 +209,22 @@ export const Flow = memo(() => { const { nodes, edges } = store.getState().nodes.present; const nodeChanges: NodeChange[] = []; const edgeChanges: EdgeChange[] = []; - nodes.forEach(({ id }) => { - nodeChanges.push({ type: 'select', id, selected: true }); + nodes.forEach(({ id, selected }) => { + if (!selected) { + nodeChanges.push({ type: 'select', id, selected: true }); + } }); - edges.forEach(({ id }) => { - edgeChanges.push({ type: 'select', id, selected: true }); + edges.forEach(({ id, selected }) => { + if (!selected) { + edgeChanges.push({ type: 'select', id, selected: true }); + } }); - dispatch(nodesChanged(nodeChanges)); - dispatch(edgesChanged(edgeChanges)); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }, [dispatch, store] ); @@ -275,8 +285,12 @@ export const Flow = memo(() => { .forEach(({ id }) => { edgeChanges.push({ type: 'remove', id }); }); - dispatch(nodesChanged(nodeChanges)); - dispatch(edgesChanged(edgeChanges)); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }, [dispatch, store]); useHotkeys(['delete', 'backspace'], onDeleteHotkey); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index a0260c7301..983aee1d48 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -40,8 +40,15 @@ const NodeWrapper = (props: NodeWrapperProps) => { (e: MouseEvent) => { if (!e.ctrlKey && !e.altKey && !e.metaKey && !e.shiftKey) { const { nodes } = store.getState().nodes.present; - const nodeChanges: NodeChange[] = nodes.map(({ id }) => ({ type: 'select', id, selected: id === nodeId })); - dispatch(nodesChanged(nodeChanges)); + const nodeChanges: NodeChange[] = []; + nodes.forEach(({ id, selected }) => { + if (selected !== (id === nodeId)) { + nodeChanges.push({ type: 'select', id, selected: id === nodeId }); + } + }); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } } onCloseGlobal(); }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index 36491e80bc..0bca73731e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -92,15 +92,15 @@ export const useConnection = () => { ); if (connection) { const newEdge = connectionToEdge(connection); - const changes: EdgeChange[] = [{ type: 'add', item: newEdge }]; + const edgeChanges: EdgeChange[] = [{ type: 'add', item: newEdge }]; const nodesToUpdate = [newEdge.source, newEdge.target]; if (edgePendingUpdate) { $didUpdateEdge.set(true); - changes.push({ type: 'remove', id: edgePendingUpdate.id }); + edgeChanges.push({ type: 'remove', id: edgePendingUpdate.id }); nodesToUpdate.push(edgePendingUpdate.source, edgePendingUpdate.target); } - dispatch(edgesChanged(changes)); + dispatch(edgesChanged(edgeChanges)); updateNodeInternals(nodesToUpdate); } $pendingConnection.set(null); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts index 4ca331d61b..32db806cde 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts @@ -73,12 +73,14 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { const nodeChanges: NodeChange[] = []; const edgeChanges: EdgeChange[] = []; // Deselect existing nodes - nodes.forEach(({ id }) => { - nodeChanges.push({ - type: 'select', - id, - selected: false, - }); + nodes.forEach(({ id, selected }) => { + if (selected) { + nodeChanges.push({ + type: 'select', + id, + selected: false, + }); + } }); // Add new nodes copiedNodes.forEach((n) => { @@ -88,12 +90,14 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { }); }); // Deselect existing edges - edges.forEach(({ id }) => { - edgeChanges.push({ - type: 'select', - id, - selected: false, - }); + edges.forEach(({ id, selected }) => { + if (selected) { + edgeChanges.push({ + type: 'select', + id, + selected: false, + }); + } }); // Add new edges copiedEdges.forEach((e) => { @@ -102,8 +106,12 @@ const pasteSelection = (withEdgesToCopiedNodes?: boolean) => { item: e, }); }); - dispatch(nodesChanged(nodeChanges)); - dispatch(edgesChanged(edgeChanges)); + if (nodeChanges.length > 0) { + dispatch(nodesChanged(nodeChanges)); + } + if (edgeChanges.length > 0) { + dispatch(edgesChanged(edgeChanges)); + } }; const api = { copySelection, pasteSelection }; From 26d0d55d9729f9734b0c0920d2c9466f8d0e6661 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 15:51:23 +1000 Subject: [PATCH 056/207] fix(ui): set nodeDragThreshold to prevent spurious position change events --- .../frontend/web/src/features/nodes/components/flow/Flow.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 19f56b7747..1748989394 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -326,6 +326,7 @@ export const Flow = memo(() => { deleteKeyCode={null} selectionMode={selectionMode} elevateEdgesOnSelect + nodeDragThreshold={1} > From 89b0e9e4de57c1217f4bafcf88a38bb8fec6f377 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 17:07:55 +1000 Subject: [PATCH 057/207] feat(ui): use connection validationResults directly in components --- .../nodes/Invocation/fields/FieldHandle.tsx | 19 +++++++++++-------- .../nodes/Invocation/fields/InputField.tsx | 6 +++--- .../nodes/Invocation/fields/OutputField.tsx | 4 ++-- .../nodes/hooks/useConnectionState.ts | 10 +++++----- .../store/util/makeConnectionErrorSelector.ts | 13 +++++-------- .../nodes/store/util/validateConnection.ts | 8 ++++---- 6 files changed, 30 insertions(+), 30 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 959b13c2d0..033aa61bdf 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -2,10 +2,12 @@ import { Tooltip } from '@invoke-ai/ui-library'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; +import type { ValidationResult } from 'features/nodes/store/util/validateConnection'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import type { HandleType } from 'reactflow'; import { Handle, Position } from 'reactflow'; @@ -14,11 +16,12 @@ type FieldHandleProps = { handleType: HandleType; isConnectionInProgress: boolean; isConnectionStartField: boolean; - connectionError?: string; + validationResult: ValidationResult; }; const FieldHandle = (props: FieldHandleProps) => { - const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, connectionError } = props; + const { fieldTemplate, handleType, isConnectionInProgress, isConnectionStartField, validationResult } = props; + const { t } = useTranslation(); const { name } = fieldTemplate; const type = fieldTemplate.type; const fieldTypeName = useFieldTypeName(type); @@ -43,11 +46,11 @@ const FieldHandle = (props: FieldHandleProps) => { s.insetInlineEnd = '-1rem'; } - if (isConnectionInProgress && !isConnectionStartField && connectionError) { + if (isConnectionInProgress && !isConnectionStartField && !validationResult.isValid) { s.filter = 'opacity(0.4) grayscale(0.7)'; } - if (isConnectionInProgress && connectionError) { + if (isConnectionInProgress && !validationResult.isValid) { if (isConnectionStartField) { s.cursor = 'grab'; } else { @@ -58,14 +61,14 @@ const FieldHandle = (props: FieldHandleProps) => { } return s; - }, [connectionError, handleType, isConnectionInProgress, isConnectionStartField, type]); + }, [handleType, isConnectionInProgress, isConnectionStartField, type, validationResult.isValid]); const tooltip = useMemo(() => { - if (isConnectionInProgress && connectionError) { - return connectionError; + if (isConnectionInProgress && validationResult.messageTKey) { + return t(validationResult.messageTKey); } return fieldTypeName; - }, [connectionError, fieldTypeName, isConnectionInProgress]); + }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]); return ( { const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const [isHovered, setIsHovered] = useState(false); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'inputs' }); const isMissingInput = useMemo(() => { @@ -88,7 +88,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); @@ -126,7 +126,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { handleType="target" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> )} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index f2d776a2da..94e8b62744 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -18,7 +18,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); - const { isConnected, isConnectionInProgress, isConnectionStartField, connectionError, shouldDim } = + const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } = useConnectionState({ nodeId, fieldName, kind: 'outputs' }); if (!fieldTemplate) { @@ -52,7 +52,7 @@ const OutputField = ({ nodeId, fieldName }: Props) => { handleType="source" isConnectionInProgress={isConnectionInProgress} isConnectionStartField={isConnectionStartField} - connectionError={connectionError} + validationResult={validationResult} /> ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index d218734fff..64bb72c54e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -31,7 +31,7 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta [fieldName, kind, nodeId] ); - const selectConnectionError = useMemo( + const selectValidationResult = useMemo( () => makeConnectionErrorSelector(templates, nodeId, fieldName, kind === 'inputs' ? 'target' : 'source'), [templates, nodeId, fieldName, kind] ); @@ -48,18 +48,18 @@ export const useConnectionState = ({ nodeId, fieldName, kind }: UseConnectionSta pendingConnection.fieldTemplate.fieldKind === { inputs: 'input', outputs: 'output' }[kind] ); }, [fieldName, kind, nodeId, pendingConnection]); - const connectionError = useAppSelector((s) => selectConnectionError(s, pendingConnection, edgePendingUpdate)); + const validationResult = useAppSelector((s) => selectValidationResult(s, pendingConnection, edgePendingUpdate)); const shouldDim = useMemo( - () => Boolean(isConnectionInProgress && connectionError && !isConnectionStartField), - [connectionError, isConnectionInProgress, isConnectionStartField] + () => Boolean(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField), + [validationResult, isConnectionInProgress, isConnectionStartField] ); return { isConnected, isConnectionInProgress, isConnectionStartField, - connectionError, + validationResult, shouldDim, }; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index c6d05d2c7c..ec607c60c5 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -2,8 +2,7 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import type { RootState } from 'app/store/store'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { validateConnection } from 'features/nodes/store/util/validateConnection'; -import i18n from 'i18next'; +import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection'; import type { Edge, HandleType } from 'reactflow'; /** @@ -33,14 +32,14 @@ export const makeConnectionErrorSelector = ( const { nodes, edges } = nodesSlice; if (!pendingConnection) { - return i18n.t('nodes.noConnectionInProgress'); + return buildRejectResult('nodes.noConnectionInProgress'); } if (handleType === pendingConnection.handleType) { if (handleType === 'source') { - return i18n.t('nodes.cannotConnectOutputToOutput'); + return buildRejectResult('nodes.cannotConnectOutputToOutput'); } - return i18n.t('nodes.cannotConnectInputToInput'); + return buildRejectResult('nodes.cannotConnectInputToInput'); } // we have to figure out which is the target and which is the source @@ -62,9 +61,7 @@ export const makeConnectionErrorSelector = ( edgePendingUpdate ); - if (!validationResult.isValid) { - return i18n.t(validationResult.messageTKey); - } + return validationResult; } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index 56e45c0d80..8ece852b07 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -9,7 +9,7 @@ import type { O } from 'ts-toolbelt'; type Connection = O.NonNullable; -type ValidateConnectionResult = +export type ValidationResult = | { isValid: true; messageTKey?: string; @@ -26,7 +26,7 @@ type ValidateConnectionFunc = ( templates: Templates, ignoreEdge: Edge | null, strict?: boolean -) => ValidateConnectionResult; +) => ValidationResult; const getEqualityPredicate = (c: Connection) => @@ -45,8 +45,8 @@ const getTargetEqualityPredicate = return e.target === c.target && e.targetHandle === c.targetHandle; }; -export const buildAcceptResult = (): ValidateConnectionResult => ({ isValid: true }); -export const buildRejectResult = (messageTKey: string): ValidateConnectionResult => ({ isValid: false, messageTKey }); +export const buildAcceptResult = (): ValidationResult => ({ isValid: true }); +export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey }); export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { if (c.source === c.target) { From cea1874e009361383a766327262185f7c83c4ed3 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 17:13:21 +1000 Subject: [PATCH 058/207] perf(ui): memoize WorkflowName selectors --- .../src/features/nodes/components/sidePanel/WorkflowName.tsx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx index 14852945ab..b983e12e11 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/WorkflowName.tsx @@ -7,8 +7,10 @@ import WorkflowInfoTooltipContent from './viewMode/WorkflowInfoTooltipContent'; import { WorkflowWarning } from './viewMode/WorkflowWarning'; export const WorkflowName = () => { - const { name, isTouched, mode } = useAppSelector((s) => s.workflow); const { t } = useTranslation(); + const name = useAppSelector((s) => s.workflow.name); + const isTouched = useAppSelector((s) => s.workflow.isTouched); + const mode = useAppSelector((s) => s.workflow.mode); return ( From 281bd31db2b1b9bd10c79c17433fc007c513facc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 19:26:17 +1000 Subject: [PATCH 059/207] feat(nodes): make ModelIdentifierInvocation a prototype --- invokeai/app/invocations/model.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 6f78cf43bf..94a6136fcb 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -11,6 +11,7 @@ from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Classification, invocation, invocation_output, ) @@ -106,9 +107,12 @@ class ModelIdentifierOutput(BaseInvocationOutput): tags=["model"], category="model", version="1.0.0", + classification=Classification.Prototype, ) class ModelIdentifierInvocation(BaseInvocation): - """Selects any model, outputting it.""" + """Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as + input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an + error.""" model: ModelIdentifierField = InputField(description="The model to select", title="Model") From e2f109807c4ca29ae729a4e52b014fcb1ab4b652 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 19:37:26 +1000 Subject: [PATCH 060/207] fix(ui): delete edges when their source or target no longer exists --- .../web/src/features/nodes/store/nodesSlice.ts | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 9cc641769c..c63734c871 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -93,6 +93,16 @@ export const nodesSlice = createSlice({ reducers: { nodesChanged: (state, action: PayloadAction) => { state.nodes = applyNodeChanges(action.payload, state.nodes); + // Remove edges that are no longer valid, due to a removed or otherwise changed node + const edgeChanges: EdgeChange[] = []; + state.edges.forEach((e) => { + const sourceExists = state.nodes.some((n) => n.id === e.source); + const targetExists = state.nodes.some((n) => n.id === e.target); + if (!(sourceExists && targetExists)) { + edgeChanges.push({ type: 'remove', id: e.id }); + } + }); + state.edges = applyEdgeChanges(edgeChanges, state.edges); }, edgesChanged: (state, action: PayloadAction) => { const changes = deepClone(action.payload); From ca186bca614eff10f256fb3ac706229fda744700 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 19:44:48 +1000 Subject: [PATCH 061/207] fix(ui): missed node execution state for progress images --- .../listeners/socketio/socketGeneratorProgress.ts | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts index 2dd598396a..e0c6d4f33d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketGeneratorProgress.ts @@ -1,7 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { deepClone } from 'common/util/deepClone'; -import { $nodeExecutionStates } from 'features/nodes/hooks/useExecutionState'; +import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useExecutionState'; import { zNodeStatus } from 'features/nodes/types/invocation'; import { socketGeneratorProgress } from 'services/events/actions'; @@ -18,6 +18,7 @@ export const addGeneratorProgressEventListener = (startAppListening: AppStartLis nes.status = zNodeStatus.enum.IN_PROGRESS; nes.progress = (step + 1) / total_steps; nes.progressImage = progress_image ?? null; + upsertExecutionState(nes.nodeId, nes); } }, }); From ba8bed68708b99b65184bfcb7dea149228318136 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 23:44:07 +1000 Subject: [PATCH 062/207] fix(ui): edge case resulting in no node templates when loading workflow, causing failure Depending on the user behaviour and network conditions, it's possible that we could try to load a workflow before the invocation templates are available. Fix is simple: - Use the RTKQ query hook for openAPI schema in App.tsx - Disable the load workflow buttons until w have templates parsed --- invokeai/frontend/web/src/app/components/App.tsx | 2 ++ .../ImageContextMenu/SingleSelectionMenuItems.tsx | 5 ++++- .../gallery/components/ImageViewer/CurrentImageButtons.tsx | 7 +++++-- .../WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx | 6 +++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/invokeai/frontend/web/src/app/components/App.tsx b/invokeai/frontend/web/src/app/components/App.tsx index 30d8f41200..1ff093f348 100644 --- a/invokeai/frontend/web/src/app/components/App.tsx +++ b/invokeai/frontend/web/src/app/components/App.tsx @@ -21,6 +21,7 @@ import i18n from 'i18n'; import { size } from 'lodash-es'; import { memo, useCallback, useEffect } from 'react'; import { ErrorBoundary } from 'react-error-boundary'; +import { useGetOpenAPISchemaQuery } from 'services/api/endpoints/appInfo'; import AppErrorBoundaryFallback from './AppErrorBoundaryFallback'; import PreselectedImage from './PreselectedImage'; @@ -46,6 +47,7 @@ const App = ({ config = DEFAULT_CONFIG, selectedImage }: Props) => { useSocketIO(); useGlobalModifiersInit(); useGlobalHotkeys(); + useGetOpenAPISchemaQuery(); const { dropzone, isHandlingUpload, setIsHandlingUpload } = useFullscreenDropzone(); diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx index a25f6d8c0e..f5063ea717 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageContextMenu/SingleSelectionMenuItems.tsx @@ -11,10 +11,12 @@ import { iiLayerAdded } from 'features/controlLayers/store/controlLayersSlice'; import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { useImageActions } from 'features/gallery/hooks/useImageActions'; import { sentImageToCanvas, sentImageToImg2Img } from 'features/gallery/store/actions'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { selectOptimalDimension } from 'features/parameters/store/generationSlice'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { flushSync } from 'react-dom'; import { useTranslation } from 'react-i18next'; @@ -48,6 +50,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { const isCanvasEnabled = useFeatureStatus('canvas'); const customStarUi = useStore($customStarUI); const { downloadImage } = useDownloadImage(); + const templates = useStore($templates); const { recallAll, remix, recallSeed, recallPrompts, hasMetadata, hasSeed, hasPrompts, isLoadingMetadata } = useImageActions(imageDTO?.image_name); @@ -133,7 +136,7 @@ const SingleSelectionMenuItems = (props: SingleSelectionMenuItemsProps) => { : } onClickCapture={handleLoadWorkflow} - isDisabled={!imageDTO.has_workflow} + isDisabled={!imageDTO.has_workflow || !size(templates)} > {t('nodes.loadWorkflow')} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx index ada9c35d28..d500d692fe 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageViewer/CurrentImageButtons.tsx @@ -1,4 +1,5 @@ import { ButtonGroup, IconButton, Menu, MenuButton, MenuList } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/query'; import { upscaleRequested } from 'app/store/middleware/listenerMiddleware/listeners/upscaleRequested'; @@ -12,12 +13,14 @@ import { sentImageToImg2Img } from 'features/gallery/store/actions'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { parseAndRecallImageDimensions } from 'features/metadata/util/handlers'; +import { $templates } from 'features/nodes/store/nodesSlice'; import ParamUpscalePopover from 'features/parameters/components/Upscale/ParamUpscaleSettings'; import { useIsQueueMutationInProgress } from 'features/queue/hooks/useIsQueueMutationInProgress'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import { selectSystemSlice } from 'features/system/store/systemSlice'; import { setActiveTab } from 'features/ui/store/uiSlice'; import { useGetAndLoadEmbeddedWorkflow } from 'features/workflowLibrary/hooks/useGetAndLoadEmbeddedWorkflow'; +import { size } from 'lodash-es'; import { memo, useCallback } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; @@ -48,7 +51,7 @@ const CurrentImageButtons = () => { const lastSelectedImage = useAppSelector(selectLastSelectedImage); const selection = useAppSelector((s) => s.gallery.selection); const shouldDisableToolbarButtons = useAppSelector(selectShouldDisableToolbarButtons); - + const templates = useStore($templates); const isUpscalingEnabled = useFeatureStatus('upscaling'); const isQueueMutationInProgress = useIsQueueMutationInProgress(); const { t } = useTranslation(); @@ -143,7 +146,7 @@ const CurrentImageButtons = () => { icon={} tooltip={`${t('nodes.loadWorkflow')} (W)`} aria-label={`${t('nodes.loadWorkflow')} (W)`} - isDisabled={!imageDTO?.has_workflow} + isDisabled={!imageDTO?.has_workflow || !size(templates)} onClick={handleLoadWorkflow} isLoading={getAndLoadEmbeddedWorkflowResult.isLoading} /> diff --git a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx index 8f3cb0c6f6..8006ca937f 100644 --- a/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx +++ b/invokeai/frontend/web/src/features/workflowLibrary/components/WorkflowLibraryMenu/LoadWorkflowFromGraphMenuItem.tsx @@ -1,15 +1,19 @@ import { MenuItem } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { useLoadWorkflowFromGraphModal } from 'features/workflowLibrary/components/LoadWorkflowFromGraphModal/LoadWorkflowFromGraphModal'; +import { size } from 'lodash-es'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiFlaskBold } from 'react-icons/pi'; const LoadWorkflowFromGraphMenuItem = () => { const { t } = useTranslation(); + const templates = useStore($templates); const { onOpen } = useLoadWorkflowFromGraphModal(); return ( - } onClick={onOpen}> + } onClick={onOpen} isDisabled={!size(templates)}> {t('workflows.loadFromGraph')} ); From ecfff6cb1e1cf672c44bfccfe37e6458290d0260 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 09:33:50 +1000 Subject: [PATCH 063/207] feat(api): add metadata to upload route Canvas images are saved by uploading a blob generated from the HTML canvas element. This means the existing metadata handling, inside the graph execution engine, is not available. To save metadata to canvas images, we need to provide it when uploading that blob. The upload route now has a `metadata` body param. If this is provided, we use it over any metadata embedded in the image. --- invokeai/app/api/routers/images.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 9c55ff6531..84d4a5d27f 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, from fastapi.responses import FileResponse from fastapi.routing import APIRouter from PIL import Image -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin @@ -41,14 +41,17 @@ async def upload_image( board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), + metadata: Optional[JsonValue] = Body( + default=None, description="The metadata to associate with the image", embed=True + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type or not file.content_type.startswith("image"): raise HTTPException(status_code=415, detail="Not an image") - metadata = None - workflow = None - graph = None + _metadata = None + _workflow = None + _graph = None contents = await file.read() try: @@ -62,9 +65,9 @@ async def upload_image( # TODO: retain non-invokeai metadata on upload? # attempt to parse metadata from image - metadata_raw = pil_image.info.get("invokeai_metadata", None) + metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None) if isinstance(metadata_raw, str): - metadata = metadata_raw + _metadata = metadata_raw else: ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") pass @@ -72,7 +75,7 @@ async def upload_image( # attempt to parse workflow from image workflow_raw = pil_image.info.get("invokeai_workflow", None) if isinstance(workflow_raw, str): - workflow = workflow_raw + _workflow = workflow_raw else: ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") pass @@ -80,7 +83,7 @@ async def upload_image( # attempt to extract graph from image graph_raw = pil_image.info.get("invokeai_graph", None) if isinstance(graph_raw, str): - graph = graph_raw + _graph = graph_raw else: ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image") pass @@ -92,9 +95,9 @@ async def upload_image( image_category=image_category, session_id=session_id, board_id=board_id, - metadata=metadata, - workflow=workflow, - graph=graph, + metadata=_metadata, + workflow=_workflow, + graph=_graph, is_intermediate=is_intermediate, ) From a34faf0bd8d036e595e07de0df71d976af5c0088 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 09:34:01 +1000 Subject: [PATCH 064/207] chore(ui): typegen --- .../frontend/web/src/services/api/schema.ts | 335 +++++++++++------- 1 file changed, 198 insertions(+), 137 deletions(-) diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index c1f9486bc7..cb3d11c06b 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1175,6 +1175,11 @@ export type components = { * Format: binary */ file: Blob; + /** + * Metadata + * @description The metadata to associate with the image + */ + metadata?: Record | null; }; /** * Boolean Collection Primitive @@ -2542,7 +2547,7 @@ export type components = { /** @description The control image */ image?: components["schemas"]["ImageField"]; /** @description ControlNet model to load */ - control_model: components["schemas"]["ModelIdentifierField"]; + control_model?: components["schemas"]["ModelIdentifierField"]; /** * Control Weight * @description The weight given to the ControlNet @@ -4256,7 +4261,7 @@ export type components = { * @description The nodes in this graph */ nodes: { - [key: string]: components["schemas"]["IdealSizeInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["ImageScaleInvocation"]; + [key: string]: components["schemas"]["CalculateImageTilesInvocation"] | components["schemas"]["CLIPSkipInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["LoRALoaderInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["CropLatentsCoreInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["InvertTensorMaskInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["CreateGradientMaskInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["MergeTilesToImageInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["VAELoaderInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["RectangleMaskInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["HeuristicResizeInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["CenterPadCropInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["UnsharpMaskInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["PairTileImageInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["CanvasPasteBackInvocation"] | components["schemas"]["DepthAnythingImageProcessorInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageMaskToTensorInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["ModelIdentifierInvocation"] | components["schemas"]["AlphaMaskToTensorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["DWOpenposeImageProcessorInvocation"] | components["schemas"]["MaskFromIDInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["CalculateImageTilesMinimumOverlapInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["LoRACollectionLoader"] | components["schemas"]["SDXLLoRACollectionLoader"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["SDXLLoRALoaderInvocation"] | components["schemas"]["IdealSizeInvocation"] | components["schemas"]["TileToPropertiesInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["LoRASelectorInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["CalculateImageTilesEvenSplitInvocation"]; }; /** * Edges @@ -4293,7 +4298,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["LoRALoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["String2Output"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["CLIPSkipInvocationOutput"]; + [key: string]: components["schemas"]["NoiseOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["TileToPropertiesOutput"] | components["schemas"]["CLIPSkipInvocationOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["PairTileImageOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLLoRALoaderOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CalculateImageTilesOutput"] | components["schemas"]["IdealSizeOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["ModelIdentifierOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["GradientMaskOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["String2Output"] | components["schemas"]["LoRALoaderOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["LoRASelectorOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["MaskOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["MetadataOutput"]; }; /** * Errors @@ -4635,7 +4640,7 @@ export type components = { * IP-Adapter Model * @description The IP-Adapter model. */ - ip_adapter_model: components["schemas"]["ModelIdentifierField"]; + ip_adapter_model?: components["schemas"]["ModelIdentifierField"]; /** * Clip Vision Model * @description CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models. @@ -6926,7 +6931,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -7084,7 +7089,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -7373,7 +7378,7 @@ export type components = { */ use_cache?: boolean; /** @description Main model (UNet, VAE, CLIP) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default main_model_loader @@ -8014,6 +8019,61 @@ export type components = { */ submodel_type?: components["schemas"]["SubModelType"] | null; }; + /** + * Model identifier + * @description Selects any model, outputting it its identifier. Be careful with this one! The identifier will be accepted as + * input for any model, even if the model types don't match. If you connect this to a mismatched input, you'll get an + * error. + */ + ModelIdentifierInvocation: { + /** + * Id + * @description The id of this instance of an invocation. Must be unique among all instances of invocations. + */ + id: string; + /** + * Is Intermediate + * @description Whether or not this is an intermediate invocation. + * @default false + */ + is_intermediate?: boolean; + /** + * Use Cache + * @description Whether or not to use the cache + * @default true + */ + use_cache?: boolean; + /** + * Model + * @description The model to select + */ + model?: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default model_identifier + * @constant + * @enum {string} + */ + type: "model_identifier"; + }; + /** + * ModelIdentifierOutput + * @description Model identifier output + */ + ModelIdentifierOutput: { + /** + * Model + * @description Model identifier + */ + model: components["schemas"]["ModelIdentifierField"]; + /** + * type + * @default model_identifier_output + * @constant + * @enum {string} + */ + type: "model_identifier_output"; + }; /** * ModelInstallJob * @description Object that tracks the current status of an install request. @@ -9241,7 +9301,7 @@ export type components = { * LoRA * @description LoRA model to load */ - lora: components["schemas"]["ModelIdentifierField"]; + lora?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight at which the LoRA is applied to each model @@ -9325,7 +9385,7 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_model_loader @@ -9454,7 +9514,7 @@ export type components = { */ use_cache?: boolean; /** @description SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load */ - model: components["schemas"]["ModelIdentifierField"]; + model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default sdxl_refiner_model_loader @@ -10682,7 +10742,7 @@ export type components = { * T2I-Adapter Model * @description The T2I-Adapter model. */ - t2i_adapter_model: components["schemas"]["ModelIdentifierField"]; + t2i_adapter_model?: components["schemas"]["ModelIdentifierField"]; /** * Weight * @description The weight given to the T2I-Adapter @@ -11356,7 +11416,7 @@ export type components = { * VAE * @description VAE model to load */ - vae_model: components["schemas"]["ModelIdentifierField"]; + vae_model?: components["schemas"]["ModelIdentifierField"]; /** * type * @default vae_loader @@ -11841,143 +11901,144 @@ export type components = { */ UIType: "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "T2IAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; InvocationOutputMap: { - ideal_size: components["schemas"]["IdealSizeOutput"]; - lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; - color_map_image_processor: components["schemas"]["ImageOutput"]; - img_resize: components["schemas"]["ImageOutput"]; - calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; - lineart_image_processor: components["schemas"]["ImageOutput"]; - boolean_collection: components["schemas"]["BooleanCollectionOutput"]; - ip_adapter: components["schemas"]["IPAdapterOutput"]; - face_mask_detection: components["schemas"]["FaceMaskOutput"]; - string_replace: components["schemas"]["StringOutput"]; - infill_lama: components["schemas"]["ImageOutput"]; - calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; - tile_image_processor: components["schemas"]["ImageOutput"]; calculate_image_tiles: components["schemas"]["CalculateImageTilesOutput"]; - img_blur: components["schemas"]["ImageOutput"]; - scheduler: components["schemas"]["SchedulerOutput"]; - range: components["schemas"]["IntegerCollectionOutput"]; - lora_selector: components["schemas"]["LoRASelectorOutput"]; - metadata: components["schemas"]["MetadataOutput"]; clip_skip: components["schemas"]["CLIPSkipInvocationOutput"]; - rand_float: components["schemas"]["FloatOutput"]; - float_collection: components["schemas"]["FloatCollectionOutput"]; - zoe_depth_image_processor: components["schemas"]["ImageOutput"]; - create_gradient_mask: components["schemas"]["GradientMaskOutput"]; - i2l: components["schemas"]["LatentsOutput"]; - dynamic_prompt: components["schemas"]["StringCollectionOutput"]; - create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; - img_ilerp: components["schemas"]["ImageOutput"]; - tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; - infill_cv2: components["schemas"]["ImageOutput"]; - string_join_three: components["schemas"]["StringOutput"]; - denoise_latents: components["schemas"]["LatentsOutput"]; - iterate: components["schemas"]["IterateInvocationOutput"]; - step_param_easing: components["schemas"]["FloatCollectionOutput"]; - img_nsfw: components["schemas"]["ImageOutput"]; - infill_patchmatch: components["schemas"]["ImageOutput"]; - pair_tile_image: components["schemas"]["PairTileImageOutput"]; - alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; - lora_loader: components["schemas"]["LoRALoaderOutput"]; - normalbae_image_processor: components["schemas"]["ImageOutput"]; - img_hue_adjust: components["schemas"]["ImageOutput"]; - conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; - image_mask_to_tensor: components["schemas"]["MaskOutput"]; - t2i_adapter: components["schemas"]["T2IAdapterOutput"]; - infill_rgba: components["schemas"]["ImageOutput"]; - vae_loader: components["schemas"]["VAEOutput"]; - blank_image: components["schemas"]["ImageOutput"]; - latents: components["schemas"]["LatentsOutput"]; - sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - boolean: components["schemas"]["BooleanOutput"]; - float_range: components["schemas"]["FloatCollectionOutput"]; - integer: components["schemas"]["IntegerOutput"]; - mul: components["schemas"]["IntegerOutput"]; - img_crop: components["schemas"]["ImageOutput"]; - face_identifier: components["schemas"]["ImageOutput"]; - main_model_loader: components["schemas"]["ModelLoaderOutput"]; - mlsd_image_processor: components["schemas"]["ImageOutput"]; - esrgan: components["schemas"]["ImageOutput"]; - integer_math: components["schemas"]["IntegerOutput"]; - sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; - img_chan: components["schemas"]["ImageOutput"]; - round_float: components["schemas"]["FloatOutput"]; - random_range: components["schemas"]["IntegerCollectionOutput"]; - image_collection: components["schemas"]["ImageCollectionOutput"]; - sub: components["schemas"]["IntegerOutput"]; - lblend: components["schemas"]["LatentsOutput"]; - sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; - cv_inpaint: components["schemas"]["ImageOutput"]; - sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; - invert_tensor_mask: components["schemas"]["MaskOutput"]; - image: components["schemas"]["ImageOutput"]; - img_mul: components["schemas"]["ImageOutput"]; - l2i: components["schemas"]["ImageOutput"]; - canny_image_processor: components["schemas"]["ImageOutput"]; - save_image: components["schemas"]["ImageOutput"]; - string_split: components["schemas"]["String2Output"]; - segment_anything_processor: components["schemas"]["ImageOutput"]; - heuristic_resize: components["schemas"]["ImageOutput"]; - face_off: components["schemas"]["FaceOffOutput"]; - img_channel_offset: components["schemas"]["ImageOutput"]; - img_conv: components["schemas"]["ImageOutput"]; - add: components["schemas"]["IntegerOutput"]; - infill_tile: components["schemas"]["ImageOutput"]; - color: components["schemas"]["ColorOutput"]; - mediapipe_face_processor: components["schemas"]["ImageOutput"]; - freeu: components["schemas"]["UNetOutput"]; - pidi_image_processor: components["schemas"]["ImageOutput"]; - depth_anything_image_processor: components["schemas"]["ImageOutput"]; - noise: components["schemas"]["NoiseOutput"]; - collect: components["schemas"]["CollectInvocationOutput"]; - content_shuffle_image_processor: components["schemas"]["ImageOutput"]; - string_split_neg: components["schemas"]["StringPosNegOutput"]; - img_lerp: components["schemas"]["ImageOutput"]; - leres_image_processor: components["schemas"]["ImageOutput"]; - div: components["schemas"]["IntegerOutput"]; - lscale: components["schemas"]["LatentsOutput"]; - metadata_item: components["schemas"]["MetadataItemOutput"]; - seamless: components["schemas"]["SeamlessModeOutput"]; - img_paste: components["schemas"]["ImageOutput"]; - string: components["schemas"]["StringOutput"]; - mask_combine: components["schemas"]["ImageOutput"]; - float_math: components["schemas"]["FloatOutput"]; - tomask: components["schemas"]["ImageOutput"]; - img_channel_multiply: components["schemas"]["ImageOutput"]; - sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; - mask_edge: components["schemas"]["ImageOutput"]; - merge_tiles_to_image: components["schemas"]["ImageOutput"]; range_of_size: components["schemas"]["IntegerCollectionOutput"]; - sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; - canvas_paste_back: components["schemas"]["ImageOutput"]; + lora_loader: components["schemas"]["LoRALoaderOutput"]; + img_mul: components["schemas"]["ImageOutput"]; + div: components["schemas"]["IntegerOutput"]; + crop_latents: components["schemas"]["LatentsOutput"]; + scheduler: components["schemas"]["SchedulerOutput"]; + blank_image: components["schemas"]["ImageOutput"]; + invert_tensor_mask: components["schemas"]["MaskOutput"]; controlnet: components["schemas"]["ControlOutput"]; - dw_openpose_image_processor: components["schemas"]["ImageOutput"]; - string_collection: components["schemas"]["StringCollectionOutput"]; - float_to_int: components["schemas"]["IntegerOutput"]; - color_correct: components["schemas"]["ImageOutput"]; - unsharp_mask: components["schemas"]["ImageOutput"]; - float: components["schemas"]["FloatOutput"]; + create_gradient_mask: components["schemas"]["GradientMaskOutput"]; + img_crop: components["schemas"]["ImageOutput"]; + img_chan: components["schemas"]["ImageOutput"]; + iterate: components["schemas"]["IterateInvocationOutput"]; + img_hue_adjust: components["schemas"]["ImageOutput"]; + merge_tiles_to_image: components["schemas"]["ImageOutput"]; + denoise_latents: components["schemas"]["LatentsOutput"]; + string_split_neg: components["schemas"]["StringPosNegOutput"]; + metadata_item: components["schemas"]["MetadataItemOutput"]; + face_off: components["schemas"]["FaceOffOutput"]; + zoe_depth_image_processor: components["schemas"]["ImageOutput"]; + prompt_from_file: components["schemas"]["StringCollectionOutput"]; + img_nsfw: components["schemas"]["ImageOutput"]; + infill_lama: components["schemas"]["ImageOutput"]; + vae_loader: components["schemas"]["VAEOutput"]; + noise: components["schemas"]["NoiseOutput"]; + midas_depth_image_processor: components["schemas"]["ImageOutput"]; + string: components["schemas"]["StringOutput"]; + img_conv: components["schemas"]["ImageOutput"]; + mlsd_image_processor: components["schemas"]["ImageOutput"]; + core_metadata: components["schemas"]["MetadataOutput"]; + float_math: components["schemas"]["FloatOutput"]; + hed_image_processor: components["schemas"]["ImageOutput"]; + lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + main_model_loader: components["schemas"]["ModelLoaderOutput"]; + infill_cv2: components["schemas"]["ImageOutput"]; + image: components["schemas"]["ImageOutput"]; + normalbae_image_processor: components["schemas"]["ImageOutput"]; rand_int: components["schemas"]["IntegerOutput"]; - mask_from_id: components["schemas"]["ImageOutput"]; - latents_collection: components["schemas"]["LatentsCollectionOutput"]; - conditioning: components["schemas"]["ConditioningOutput"]; + image_collection: components["schemas"]["ImageCollectionOutput"]; + step_param_easing: components["schemas"]["FloatCollectionOutput"]; + infill_patchmatch: components["schemas"]["ImageOutput"]; + sdxl_model_loader: components["schemas"]["SDXLModelLoaderOutput"]; + string_collection: components["schemas"]["StringCollectionOutput"]; + img_paste: components["schemas"]["ImageOutput"]; + infill_rgba: components["schemas"]["ImageOutput"]; integer_collection: components["schemas"]["IntegerCollectionOutput"]; + float_to_int: components["schemas"]["IntegerOutput"]; + tile_image_processor: components["schemas"]["ImageOutput"]; + mask_combine: components["schemas"]["ImageOutput"]; + merge_metadata: components["schemas"]["MetadataOutput"]; + rectangle_mask: components["schemas"]["MaskOutput"]; + color_map_image_processor: components["schemas"]["ImageOutput"]; + img_lerp: components["schemas"]["ImageOutput"]; + mask_edge: components["schemas"]["ImageOutput"]; + ip_adapter: components["schemas"]["IPAdapterOutput"]; + lineart_image_processor: components["schemas"]["ImageOutput"]; + seamless: components["schemas"]["SeamlessModeOutput"]; + img_channel_offset: components["schemas"]["ImageOutput"]; + sdxl_refiner_model_loader: components["schemas"]["SDXLRefinerModelLoaderOutput"]; + range: components["schemas"]["IntegerCollectionOutput"]; + lresize: components["schemas"]["LatentsOutput"]; + freeu: components["schemas"]["UNetOutput"]; string_join: components["schemas"]["StringOutput"]; compel: components["schemas"]["ConditioningOutput"]; - crop_latents: components["schemas"]["LatentsOutput"]; + collect: components["schemas"]["CollectInvocationOutput"]; img_watermark: components["schemas"]["ImageOutput"]; - rectangle_mask: components["schemas"]["MaskOutput"]; - prompt_from_file: components["schemas"]["StringCollectionOutput"]; - merge_metadata: components["schemas"]["MetadataOutput"]; + float_range: components["schemas"]["FloatCollectionOutput"]; + sdxl_compel_prompt: components["schemas"]["ConditioningOutput"]; + i2l: components["schemas"]["LatentsOutput"]; + sdxl_refiner_compel_prompt: components["schemas"]["ConditioningOutput"]; + float: components["schemas"]["FloatOutput"]; + dynamic_prompt: components["schemas"]["StringCollectionOutput"]; + save_image: components["schemas"]["ImageOutput"]; + heuristic_resize: components["schemas"]["ImageOutput"]; + lblend: components["schemas"]["LatentsOutput"]; + tomask: components["schemas"]["ImageOutput"]; + leres_image_processor: components["schemas"]["ImageOutput"]; + lscale: components["schemas"]["LatentsOutput"]; + conditioning: components["schemas"]["ConditioningOutput"]; + mediapipe_face_processor: components["schemas"]["ImageOutput"]; + esrgan: components["schemas"]["ImageOutput"]; img_pad_crop: components["schemas"]["ImageOutput"]; - midas_depth_image_processor: components["schemas"]["ImageOutput"]; - core_metadata: components["schemas"]["MetadataOutput"]; + content_shuffle_image_processor: components["schemas"]["ImageOutput"]; + color_correct: components["schemas"]["ImageOutput"]; + unsharp_mask: components["schemas"]["ImageOutput"]; + infill_tile: components["schemas"]["ImageOutput"]; + canny_image_processor: components["schemas"]["ImageOutput"]; show_image: components["schemas"]["ImageOutput"]; - hed_image_processor: components["schemas"]["ImageOutput"]; - lresize: components["schemas"]["LatentsOutput"]; - lineart_anime_image_processor: components["schemas"]["ImageOutput"]; + pidi_image_processor: components["schemas"]["ImageOutput"]; + pair_tile_image: components["schemas"]["PairTileImageOutput"]; + segment_anything_processor: components["schemas"]["ImageOutput"]; + rand_float: components["schemas"]["FloatOutput"]; + canvas_paste_back: components["schemas"]["ImageOutput"]; + depth_anything_image_processor: components["schemas"]["ImageOutput"]; + img_channel_multiply: components["schemas"]["ImageOutput"]; + metadata: components["schemas"]["MetadataOutput"]; + string_replace: components["schemas"]["StringOutput"]; + image_mask_to_tensor: components["schemas"]["MaskOutput"]; + mul: components["schemas"]["IntegerOutput"]; img_scale: components["schemas"]["ImageOutput"]; + model_identifier: components["schemas"]["ModelIdentifierOutput"]; + alpha_mask_to_tensor: components["schemas"]["MaskOutput"]; + latents: components["schemas"]["LatentsOutput"]; + dw_openpose_image_processor: components["schemas"]["ImageOutput"]; + mask_from_id: components["schemas"]["ImageOutput"]; + conditioning_collection: components["schemas"]["ConditioningCollectionOutput"]; + round_float: components["schemas"]["FloatOutput"]; + face_mask_detection: components["schemas"]["FaceMaskOutput"]; + calculate_image_tiles_min_overlap: components["schemas"]["CalculateImageTilesOutput"]; + img_resize: components["schemas"]["ImageOutput"]; + l2i: components["schemas"]["ImageOutput"]; + color: components["schemas"]["ColorOutput"]; + lora_collection_loader: components["schemas"]["LoRALoaderOutput"]; + sdxl_lora_collection_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + string_join_three: components["schemas"]["StringOutput"]; + sub: components["schemas"]["IntegerOutput"]; + img_blur: components["schemas"]["ImageOutput"]; + float_collection: components["schemas"]["FloatCollectionOutput"]; + integer: components["schemas"]["IntegerOutput"]; + face_identifier: components["schemas"]["ImageOutput"]; + latents_collection: components["schemas"]["LatentsCollectionOutput"]; + cv_inpaint: components["schemas"]["ImageOutput"]; + t2i_adapter: components["schemas"]["T2IAdapterOutput"]; + create_denoise_mask: components["schemas"]["DenoiseMaskOutput"]; + random_range: components["schemas"]["IntegerCollectionOutput"]; + sdxl_lora_loader: components["schemas"]["SDXLLoRALoaderOutput"]; + ideal_size: components["schemas"]["IdealSizeOutput"]; + tile_to_properties: components["schemas"]["TileToPropertiesOutput"]; + img_ilerp: components["schemas"]["ImageOutput"]; + integer_math: components["schemas"]["IntegerOutput"]; + add: components["schemas"]["IntegerOutput"]; + boolean: components["schemas"]["BooleanOutput"]; + string_split: components["schemas"]["String2Output"]; + lora_selector: components["schemas"]["LoRASelectorOutput"]; + boolean_collection: components["schemas"]["BooleanCollectionOutput"]; + calculate_image_tiles_even_split: components["schemas"]["CalculateImageTilesOutput"]; }; }; responses: never; From c94742bde678a5993f908fb7d61c8fa4528e5a67 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 09:35:34 +1000 Subject: [PATCH 065/207] feat(ui): add canvas objects to metadata when saving canvas to gallery --- .../listenerMiddleware/listeners/canvasSavedToGallery.ts | 4 ++++ invokeai/frontend/web/src/services/api/endpoints/images.ts | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts index e3ba988886..7f456e9a68 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/canvasSavedToGallery.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; +import { parseify } from 'common/util/serialize'; import { canvasSavedToGallery } from 'features/canvas/store/actions'; import { getBaseLayerBlob } from 'features/canvas/util/getBaseLayerBlob'; import { addToast } from 'features/system/store/systemSlice'; @@ -43,6 +44,9 @@ export const addCanvasSavedToGalleryListener = (startAppListening: AppStartListe type: 'TOAST', toastOptions: { title: t('toast.canvasSavedGallery') }, }, + metadata: { + _canvas_objects: parseify(state.canvas.layerState.objects), + }, }) ); }, diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 98c253b479..14edf6fb87 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -571,11 +571,13 @@ export const imagesApi = api.injectEndpoints({ session_id?: string; board_id?: string; crop_visible?: boolean; + metadata?: JSONObject; } >({ - query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible }) => { + query: ({ file, image_category, is_intermediate, session_id, board_id, crop_visible, metadata }) => { const formData = new FormData(); formData.append('file', file); + formData.append('metadata', JSON.stringify(metadata)); return { url: buildImagesUrl('upload'), method: 'POST', From f4625c2671c94c1d9bda79678aca143e412bfa24 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 09:35:47 +1000 Subject: [PATCH 066/207] feat(ui): add canvas objects to metadat a for all canvas graphs --- .../util/graph/canvas/buildCanvasImageToImageGraph.ts | 1 + .../nodes/util/graph/canvas/buildCanvasInpaintGraph.ts | 10 ++++++++++ .../util/graph/canvas/buildCanvasOutpaintGraph.ts | 10 ++++++++++ .../graph/canvas/buildCanvasSDXLImageToImageGraph.ts | 1 + .../util/graph/canvas/buildCanvasSDXLInpaintGraph.ts | 10 ++++++++++ .../util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts | 10 ++++++++++ .../graph/canvas/buildCanvasSDXLTextToImageGraph.ts | 1 + .../util/graph/canvas/buildCanvasTextToImageGraph.ts | 1 + 8 files changed, 44 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts index 8f5fe9f2b8..5c89dcbf29 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasImageToImageGraph.ts @@ -330,6 +330,7 @@ export const buildCanvasImageToImageGraph = async ( clip_skip: clipSkip, strength, init_image: initialImage.image_name, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts index c995c38a3c..20304b8830 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_INPAINT_GRAPH, CANVAS_OUTPUT, @@ -421,6 +422,15 @@ export const buildCanvasInpaintGraph = async ( }); } + addCoreMetadataNode( + graph, + { + generation_mode: 'inpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts index e4a9b11b96..2c85b20222 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPAINT_GRAPH, CANVAS_OUTPUT, @@ -579,6 +580,15 @@ export const buildCanvasOutpaintGraph = async ( ); } + addCoreMetadataNode( + graph, + { + generation_mode: 'outpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts index 186dfa53b3..b4549ff582 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLImageToImageGraph.ts @@ -332,6 +332,7 @@ export const buildCanvasSDXLImageToImageGraph = async ( init_image: initialImage.image_name, positive_style_prompt: positiveStylePrompt, negative_style_prompt: negativeStylePrompt, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts index 277b713079..dfbe2436d2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLInpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPUT, INPAINT_CREATE_MASK, @@ -432,6 +433,15 @@ export const buildCanvasSDXLInpaintGraph = async ( }); } + addCoreMetadataNode( + graph, + { + generation_mode: 'sdxl_inpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts index b09d7d8b90..d58796575c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLOutpaintGraph.ts @@ -1,5 +1,6 @@ import { logger } from 'app/logging/logger'; import type { RootState } from 'app/store/store'; +import { addCoreMetadataNode } from 'features/nodes/util/graph/canvas/metadata'; import { CANVAS_OUTPUT, INPAINT_CREATE_MASK, @@ -588,6 +589,15 @@ export const buildCanvasSDXLOutpaintGraph = async ( ); } + addCoreMetadataNode( + graph, + { + generation_mode: 'sdxl_outpaint', + _canvas_objects: state.canvas.layerState.objects, + }, + CANVAS_OUTPUT + ); + // Add Seamless To Graph if (seamlessXAxis || seamlessYAxis) { addSeamlessToLinearGraph(state, graph, modelLoaderNodeId); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts index b2a8aa6ada..b9e8e011b3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasSDXLTextToImageGraph.ts @@ -291,6 +291,7 @@ export const buildCanvasSDXLTextToImageGraph = async (state: RootState): Promise steps, rand_device: use_cpu ? 'cpu' : 'cuda', scheduler, + _canvas_objects: state.canvas.layerState.objects, }, CANVAS_OUTPUT ); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts index 8ce5134480..fe33ab5cf3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/buildCanvasTextToImageGraph.ts @@ -280,6 +280,7 @@ export const buildCanvasTextToImageGraph = async (state: RootState): Promise Date: Sat, 18 May 2024 09:12:10 +1000 Subject: [PATCH 067/207] fix(ui): fix t2i adapter dimensions error message It now indicates the correct dimension of 64 (SD1.5) or 32 (SDXL) - before was hardcoded to 64. --- invokeai/frontend/web/public/locales/en.json | 2 +- invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 1f44e641fc..5dd411c544 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -951,7 +951,7 @@ "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "controlAdapterNoImageSelected": "no Control Adapter image selected", "controlAdapterImageNotProcessed": "Control Adapter image not processed", - "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of 64", + "t2iAdapterIncompatibleDimensions": "T2I Adapter requires image dimension to be multiples of {{multiple}}", "ipAdapterNoModelSelected": "no IP adapter selected", "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 41d6f4607e..dbf3c41480 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -137,7 +137,7 @@ const createSelector = (templates: Templates) => if (l.controlAdapter.type === 't2i_adapter') { const multiple = model?.base === 'sdxl' ? 32 : 64; if (size.width % multiple !== 0 || size.height % multiple !== 0) { - problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); + problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions', { multiple })); } } } From dba8c43ecbbde65708da82baa3eae5a8a1a96f27 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 22:54:11 +1000 Subject: [PATCH 068/207] feat(ui): explicit field type cardinality Replace the `isCollection` and `isCollectionOrScalar` flags with a single enum value `cardinality`. Valid values are `SINGLE`, `COLLECTION` and `SINGLE_OR_COLLECTION`. Why: - The two flags were mutually exclusive, but this wasn't enforce. You could create a field type that had both `isCollection` and `isCollectionOrScalar` set to true, whuch makes no sense. - There was no explicit declaration for scalar/single types. - Checking if a type had only a single flag was tedious. Thanks to a change a couple months back in which the workflows schema was revised, field types are internal implementation details. Changes to them are non-breaking. --- .../frontend/web/src/features/nodes/types/field.ts | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index 8a1a0b5039..e2a84e3390 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -54,9 +54,10 @@ const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ fieldKind: z.literal('output'), }); +const zCardinality = z.enum(['SINGLE', 'COLLECTION', 'SINGLE_OR_COLLECTION']); + const zFieldTypeBase = z.object({ - isCollection: z.boolean(), - isCollectionOrScalar: z.boolean(), + cardinality: zCardinality, }); export const zFieldIdentifier = z.object({ @@ -168,6 +169,11 @@ export const isStatefulFieldType = (fieldType: FieldType): fieldType is Stateful (statefulFieldTypeNames as string[]).includes(fieldType.name); const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); export type FieldType = z.infer; + +export const isSingle = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.SINGLE; +export const isCollection = (fieldType: FieldType): boolean => fieldType.cardinality === zCardinality.enum.COLLECTION; +export const isSingleOrCollection = (fieldType: FieldType): boolean => + fieldType.cardinality === zCardinality.enum.SINGLE_OR_COLLECTION; // #endregion // #region IntegerField From 8062a47d16c713afd3374e9199a72e93f85cb634 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 22:57:12 +1000 Subject: [PATCH 069/207] fix(ui): use new field type cardinality throughout app Update business logic and tests. --- .../nodes/Invocation/fields/FieldHandle.tsx | 6 +- .../hooks/useAnyOrDirectInputFieldNames.ts | 3 +- .../hooks/useConnectionInputFieldNames.ts | 3 +- .../nodes/hooks/usePrettyFieldType.ts | 6 +- .../nodes/store/util/areTypesEqual.test.ts | 73 +++----- .../store/util/getCollectItemType.test.ts | 2 +- .../features/nodes/store/util/testUtils.ts | 93 ++++----- .../util/validateConnectionTypes.test.ts | 176 +++++++++--------- .../store/util/validateConnectionTypes.ts | 47 +++-- .../nodes/util/schema/parseFieldType.test.ts | 83 +++++---- .../nodes/util/schema/parseFieldType.ts | 27 +-- .../features/nodes/util/schema/parseSchema.ts | 10 +- 12 files changed, 239 insertions(+), 290 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 033aa61bdf..143dee983f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -4,7 +4,7 @@ import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdge import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import type { ValidationResult } from 'features/nodes/store/util/validateConnection'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; -import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; +import { type FieldInputTemplate, type FieldOutputTemplate, isSingle } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -29,11 +29,11 @@ const FieldHandle = (props: FieldHandleProps) => { const isModelType = MODEL_TYPES.some((t) => t === type.name); const color = getFieldColor(type); const s: CSSProperties = { - backgroundColor: type.isCollection || type.isCollectionOrScalar ? colorTokenToCssVar('base.900') : color, + backgroundColor: !isSingle(type) ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: type.isCollection || type.isCollectionOrScalar ? 4 : 0, + borderWidth: !isSingle(type) ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 4 : '100%', diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index 3b7a1b74c1..7fae0de16e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -1,5 +1,6 @@ import { EMPTY_ARRAY } from 'app/store/constants'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { isSingleOrCollection } from 'features/nodes/types/field'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; @@ -11,7 +12,7 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string): string[] => { const fieldNames = useMemo(() => { const fields = map(template.inputs).filter((field) => { return ( - (['any', 'direct'].includes(field.input) || field.type.isCollectionOrScalar) && + (['any', 'direct'].includes(field.input) || isSingleOrCollection(field.type)) && keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); }); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index d071ac76d2..16ace597c1 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -1,5 +1,6 @@ import { EMPTY_ARRAY } from 'app/store/constants'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { isSingleOrCollection } from 'features/nodes/types/field'; import { getSortedFilteredFieldNames } from 'features/nodes/util/node/getSortedFilteredFieldNames'; import { TEMPLATE_BUILDER_MAP } from 'features/nodes/util/schema/buildFieldInputTemplate'; import { keys, map } from 'lodash-es'; @@ -11,7 +12,7 @@ export const useConnectionInputFieldNames = (nodeId: string): string[] => { // get the visible fields const fields = map(template.inputs).filter( (field) => - (field.input === 'connection' && !field.type.isCollectionOrScalar) || + (field.input === 'connection' && !isSingleOrCollection(field.type)) || !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts index df4b742842..2600eae078 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts @@ -1,4 +1,4 @@ -import type { FieldType } from 'features/nodes/types/field'; +import { type FieldType, isCollection, isSingleOrCollection } from 'features/nodes/types/field'; import { useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -10,10 +10,10 @@ export const useFieldTypeName = (fieldType?: FieldType): string => { return ''; } const { name } = fieldType; - if (fieldType.isCollection) { + if (isCollection(fieldType)) { return t('nodes.collectionFieldType', { name }); } - if (fieldType.isCollectionOrScalar) { + if (isSingleOrCollection(fieldType)) { return t('nodes.collectionOrScalarFieldType', { name }); } return name; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts index 7be307d07e..ae9d4f6742 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/areTypesEqual.test.ts @@ -1,99 +1,84 @@ +import type { FieldType } from 'features/nodes/types/field'; import { describe, expect, it } from 'vitest'; import { areTypesEqual } from './areTypesEqual'; describe(areTypesEqual.name, () => { it('should handle equal source and target type', () => { - const sourceType = { + const sourceType: FieldType = { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', originalType: { name: 'Foo', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; - const targetType = { + const targetType: FieldType = { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', originalType: { name: 'Bar', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; expect(areTypesEqual(sourceType, targetType)).toBe(true); }); it('should handle equal source type and original target type', () => { - const sourceType = { + const sourceType: FieldType = { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', originalType: { name: 'Foo', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; - const targetType = { - name: 'Bar', - isCollection: false, - isCollectionOrScalar: false, + const targetType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', originalType: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; expect(areTypesEqual(sourceType, targetType)).toBe(true); }); it('should handle equal original source type and target type', () => { - const sourceType = { - name: 'Foo', - isCollection: false, - isCollectionOrScalar: false, + const sourceType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', originalType: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; - const targetType = { + const targetType: FieldType = { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', originalType: { name: 'Bar', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; expect(areTypesEqual(sourceType, targetType)).toBe(true); }); it('should handle equal original source type and original target type', () => { - const sourceType = { - name: 'Foo', - isCollection: false, - isCollectionOrScalar: false, + const sourceType: FieldType = { + name: 'MainModelField', + cardinality: 'SINGLE', originalType: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; - const targetType = { - name: 'Bar', - isCollection: false, - isCollectionOrScalar: false, + const targetType: FieldType = { + name: 'LoRAModelField', + cardinality: 'SINGLE', originalType: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }; expect(areTypesEqual(sourceType, targetType)).toBe(true); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts index 935250b697..be0b553d8b 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts @@ -11,7 +11,7 @@ describe(getCollectItemType.name, () => { const n2 = buildNode(collect); const e1 = buildEdge(n1.id, 'value', n2.id, 'item'); const result = getCollectItemType(templates, [n1, n2], [e1], n2.id); - expect(result).toEqual({ name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }); + expect(result).toEqual({ name: 'IntegerField', cardinality: 'SINGLE' }); }); it('should return null if the collect node does not have any connections', () => { const n1 = buildNode(collect); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 5155bb14ea..83988d55ea 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -33,8 +33,7 @@ export const add: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 0, }, @@ -48,8 +47,7 @@ export const add: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 0, }, @@ -62,8 +60,7 @@ export const add: InvocationTemplate = { description: 'The output integer', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -91,8 +88,7 @@ export const sub: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 0, }, @@ -106,8 +102,7 @@ export const sub: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 0, }, @@ -120,8 +115,7 @@ export const sub: InvocationTemplate = { description: 'The output integer', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -150,8 +144,7 @@ export const collect: InvocationTemplate = { ui_type: 'CollectionItemField', type: { name: 'CollectionItemField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, }, @@ -163,8 +156,7 @@ export const collect: InvocationTemplate = { description: 'The collection of input items', type: { name: 'CollectionField', - isCollection: true, - isCollectionOrScalar: false, + cardinality: 'COLLECTION', }, ui_hidden: false, ui_type: 'CollectionField', @@ -193,12 +185,11 @@ const scheduler: InvocationTemplate = { ui_type: 'SchedulerField', type: { name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', + originalType: { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, default: 'euler', @@ -212,12 +203,11 @@ const scheduler: InvocationTemplate = { description: 'Scheduler to use during inference', type: { name: 'SchedulerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', + originalType: { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, ui_hidden: false, @@ -248,12 +238,11 @@ export const main_model_loader: InvocationTemplate = { ui_type: 'MainModelField', type: { name: 'MainModelField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', + originalType: { name: 'ModelIdentifierField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, }, @@ -266,8 +255,7 @@ export const main_model_loader: InvocationTemplate = { description: 'VAE', type: { name: 'VAEField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -278,8 +266,7 @@ export const main_model_loader: InvocationTemplate = { description: 'CLIP (tokenizer, text encoder, LoRAs) and skipped layer count', type: { name: 'CLIPField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -290,8 +277,7 @@ export const main_model_loader: InvocationTemplate = { description: 'UNet (scheduler, LoRAs)', type: { name: 'UNetField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -319,8 +305,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'BoardField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, metadata: { @@ -333,8 +318,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'MetadataField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, image: { @@ -347,8 +331,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'ImageField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, }, width: { @@ -361,8 +344,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 512, exclusiveMinimum: 0, @@ -377,8 +359,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, default: 512, exclusiveMinimum: 0, @@ -393,8 +374,7 @@ export const img_resize: InvocationTemplate = { ui_hidden: false, type: { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, options: ['nearest', 'box', 'bilinear', 'hamming', 'bicubic', 'lanczos'], default: 'bicubic', @@ -408,8 +388,7 @@ export const img_resize: InvocationTemplate = { description: 'The output image', type: { name: 'ImageField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -420,8 +399,7 @@ export const img_resize: InvocationTemplate = { description: 'The width of the image in pixels', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -432,8 +410,7 @@ export const img_resize: InvocationTemplate = { description: 'The height of the image in pixels', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -462,8 +439,7 @@ const iterate: InvocationTemplate = { ui_type: 'CollectionField', type: { name: 'CollectionField', - isCollection: true, - isCollectionOrScalar: false, + cardinality: 'COLLECTION', }, }, }, @@ -475,8 +451,7 @@ const iterate: InvocationTemplate = { description: 'The item being iterated over', type: { name: 'CollectionItemField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, ui_type: 'CollectionItemField', @@ -488,8 +463,7 @@ const iterate: InvocationTemplate = { description: 'The index of the item', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, @@ -500,8 +474,7 @@ const iterate: InvocationTemplate = { description: 'The total number of items', type: { name: 'IntegerField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }, ui_hidden: false, }, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts index 10344dd349..56d4cfe70a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.test.ts @@ -4,148 +4,148 @@ import { validateConnectionTypes } from './validateConnectionTypes'; describe(validateConnectionTypes.name, () => { describe('generic cases', () => { - it('should accept Scalar to Scalar of same type', () => { + it('should accept SINGLE to SINGLE of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'FooField', cardinality: 'SINGLE' } ); expect(r).toBe(true); }); - it('should accept Collection to Collection of same type', () => { + it('should accept COLLECTION to COLLECTION of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, - { name: 'FooField', isCollection: true, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'COLLECTION' } ); expect(r).toBe(true); }); - it('should accept Scalar to CollectionOrScalar of same type', () => { + it('should accept SINGLE to SINGLE_OR_COLLECTION of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it('should accept Collection to CollectionOrScalar of same type', () => { + it('should accept COLLECTION to SINGLE_OR_COLLECTION of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, - { name: 'FooField', isCollection: false, isCollectionOrScalar: true } + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it('should reject Collection to Scalar of same type', () => { + it('should reject COLLECTION to SINGLE of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: true, isCollectionOrScalar: false }, - { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE' } ); expect(r).toBe(false); }); - it('should reject CollectionOrScalar to Scalar of same type', () => { + it('should reject SINGLE_OR_COLLECTION to SINGLE of same type', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: true }, - { name: 'FooField', isCollection: false, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'FooField', cardinality: 'SINGLE' } ); expect(r).toBe(false); }); it('should reject mismatched types', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'BarField', isCollection: false, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'BarField', cardinality: 'SINGLE' } ); expect(r).toBe(false); }); }); describe('special cases', () => { - it('should reject a collection input to a collection input', () => { + it('should reject a COLLECTION input to a COLLECTION input', () => { const r = validateConnectionTypes( - { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false }, - { name: 'CollectionField', isCollection: true, isCollectionOrScalar: false } + { name: 'CollectionField', cardinality: 'COLLECTION' }, + { name: 'CollectionField', cardinality: 'COLLECTION' } ); expect(r).toBe(false); }); it('should accept equal types', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE' } ); expect(r).toBe(true); }); describe('CollectionItemField', () => { - it('should accept CollectionItemField to any Scalar target', () => { + it('should accept CollectionItemField to any SINGLE target', () => { const r = validateConnectionTypes( - { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false } + { name: 'CollectionItemField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE' } ); expect(r).toBe(true); }); - it('should accept CollectionItemField to any CollectionOrScalar target', () => { + it('should accept CollectionItemField to any SINGLE_OR_COLLECTION target', () => { const r = validateConnectionTypes( - { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + { name: 'CollectionItemField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it('should accept any non-Collection to CollectionItemField', () => { + it('should accept any SINGLE to CollectionItemField', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, - { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } ); expect(r).toBe(true); }); - it('should reject any Collection to CollectionItemField', () => { + it('should reject any COLLECTION to CollectionItemField', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, - { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + { name: 'IntegerField', cardinality: 'COLLECTION' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } ); expect(r).toBe(false); }); - it('should reject any CollectionOrScalar to CollectionItemField', () => { + it('should reject any SINGLE_OR_COLLECTION to CollectionItemField', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, - { name: 'CollectionItemField', isCollection: false, isCollectionOrScalar: false } + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'CollectionItemField', cardinality: 'SINGLE' } ); expect(r).toBe(false); }); }); - describe('CollectionOrScalar', () => { - it('should accept any Scalar of same type to CollectionOrScalar', () => { + describe('SINGLE_OR_COLLECTION', () => { + it('should accept any SINGLE of same type to SINGLE_OR_COLLECTION', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + { name: 'IntegerField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it('should accept any Collection of same type to CollectionOrScalar', () => { + it('should accept any COLLECTION of same type to SINGLE_OR_COLLECTION', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + { name: 'IntegerField', cardinality: 'COLLECTION' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it('should accept any CollectionOrScalar of same type to CollectionOrScalar', () => { + it('should accept any SINGLE_OR_COLLECTION of same type to SINGLE_OR_COLLECTION', () => { const r = validateConnectionTypes( - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); }); describe('CollectionField', () => { - it('should accept any CollectionField to any Collection type', () => { + it('should accept any CollectionField to any COLLECTION type', () => { const r = validateConnectionTypes( - { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false } + { name: 'CollectionField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'COLLECTION' } ); expect(r).toBe(true); }); - it('should accept any CollectionField to any CollectionOrScalar type', () => { + it('should accept any CollectionField to any SINGLE_OR_COLLECTION type', () => { const r = validateConnectionTypes( - { name: 'CollectionField', isCollection: false, isCollectionOrScalar: false }, - { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true } + { name: 'CollectionField', cardinality: 'SINGLE' }, + { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); @@ -158,62 +158,62 @@ describe(validateConnectionTypes.name, () => { { t1: 'IntegerField', t2: 'StringField' }, { t1: 'FloatField', t2: 'StringField' }, ]; - it.each(typePairs)('should accept Scalar $t1 to Scalar $t2', ({ t1, t2 }: TypePair) => { + it.each(typePairs)('should accept SINGLE $t1 to SINGLE $t2', ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes({ name: t1, cardinality: 'SINGLE' }, { name: t2, cardinality: 'SINGLE' }); + expect(r).toBe(true); + }); + it.each(typePairs)('should accept SINGLE $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => { const r = validateConnectionTypes( - { name: t1, isCollection: false, isCollectionOrScalar: false }, - { name: t2, isCollection: false, isCollectionOrScalar: false } + { name: t1, cardinality: 'SINGLE' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); - it.each(typePairs)('should accept Scalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { + it.each(typePairs)('should accept COLLECTION $t1 to COLLECTION $t2', ({ t1, t2 }: TypePair) => { const r = validateConnectionTypes( - { name: t1, isCollection: false, isCollectionOrScalar: false }, - { name: t2, isCollection: false, isCollectionOrScalar: true } + { name: t1, cardinality: 'COLLECTION' }, + { name: t2, cardinality: 'COLLECTION' } ); expect(r).toBe(true); }); - it.each(typePairs)('should accept Collection $t1 to Collection $t2', ({ t1, t2 }: TypePair) => { + it.each(typePairs)('should accept COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', ({ t1, t2 }: TypePair) => { const r = validateConnectionTypes( - { name: t1, isCollection: true, isCollectionOrScalar: false }, - { name: t2, isCollection: true, isCollectionOrScalar: false } - ); - expect(r).toBe(true); - }); - it.each(typePairs)('should accept Collection $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { - const r = validateConnectionTypes( - { name: t1, isCollection: true, isCollectionOrScalar: false }, - { name: t2, isCollection: false, isCollectionOrScalar: true } - ); - expect(r).toBe(true); - }); - it.each(typePairs)('should accept CollectionOrScalar $t1 to CollectionOrScalar $t2', ({ t1, t2 }: TypePair) => { - const r = validateConnectionTypes( - { name: t1, isCollection: false, isCollectionOrScalar: true }, - { name: t2, isCollection: false, isCollectionOrScalar: true } + { name: t1, cardinality: 'COLLECTION' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); + it.each(typePairs)( + 'should accept SINGLE_OR_COLLECTION $t1 to SINGLE_OR_COLLECTION $t2', + ({ t1, t2 }: TypePair) => { + const r = validateConnectionTypes( + { name: t1, cardinality: 'SINGLE_OR_COLLECTION' }, + { name: t2, cardinality: 'SINGLE_OR_COLLECTION' } + ); + expect(r).toBe(true); + } + ); }); describe('AnyField', () => { - it('should accept any Scalar type to AnyField', () => { + it('should accept any SINGLE type to AnyField', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'AnyField', isCollection: false, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'SINGLE' } ); expect(r).toBe(true); }); - it('should accept any Collection type to AnyField', () => { + it('should accept any COLLECTION type to AnyField', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'AnyField', isCollection: true, isCollectionOrScalar: false } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'COLLECTION' } ); expect(r).toBe(true); }); - it('should accept any CollectionOrScalar type to AnyField', () => { + it('should accept any SINGLE_OR_COLLECTION type to AnyField', () => { const r = validateConnectionTypes( - { name: 'FooField', isCollection: false, isCollectionOrScalar: false }, - { name: 'AnyField', isCollection: false, isCollectionOrScalar: true } + { name: 'FooField', cardinality: 'SINGLE' }, + { name: 'AnyField', cardinality: 'SINGLE_OR_COLLECTION' } ); expect(r).toBe(true); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts index 778b33a7b1..a71ff513aa 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -1,5 +1,5 @@ import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; -import type { FieldType } from 'features/nodes/types/field'; +import { type FieldType, isCollection, isSingle, isSingleOrCollection } from 'features/nodes/types/field'; /** * Validates that the source and target types are compatible for a connection. @@ -27,38 +27,37 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field * - Generic Collection can connect to any other Collection or CollectionOrScalar * - Any Collection can connect to a Generic Collection */ - const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !targetType.isCollection; + const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType); - const isNonCollectionToCollectionItem = - targetType.name === 'CollectionItemField' && !sourceType.isCollection && !sourceType.isCollectionOrScalar; + const isNonCollectionToCollectionItem = isSingle(sourceType) && targetType.name === 'CollectionItemField'; - const isAnythingToCollectionOrScalarOfSameBaseType = - targetType.isCollectionOrScalar && sourceType.name === targetType.name; + const isAnythingToSingleOrCollectionOfSameBaseType = + isSingleOrCollection(targetType) && sourceType.name === targetType.name; - const isGenericCollectionToAnyCollectionOrCollectionOrScalar = - sourceType.name === 'CollectionField' && (targetType.isCollection || targetType.isCollectionOrScalar); + const isGenericCollectionToAnyCollectionOrSingleOrCollection = + sourceType.name === 'CollectionField' && !isSingle(targetType); - const isCollectionToGenericCollection = targetType.name === 'CollectionField' && sourceType.isCollection; + const isCollectionToGenericCollection = targetType.name === 'CollectionField' && isCollection(sourceType); - const isSourceScalar = !sourceType.isCollection && !sourceType.isCollectionOrScalar; - const isTargetScalar = !targetType.isCollection && !targetType.isCollectionOrScalar; - const isScalarToScalar = isSourceScalar && isTargetScalar; - const isScalarToCollectionOrScalar = isSourceScalar && targetType.isCollectionOrScalar; - const isCollectionToCollection = sourceType.isCollection && targetType.isCollection; - const isCollectionToCollectionOrScalar = sourceType.isCollection && targetType.isCollectionOrScalar; - const isCollectionOrScalarToCollectionOrScalar = sourceType.isCollectionOrScalar && targetType.isCollectionOrScalar; - const isPluralityMatch = - isScalarToScalar || + const isSourceSingle = isSingle(sourceType); + const isTargetSingle = isSingle(targetType); + const isSingleToSingle = isSourceSingle && isTargetSingle; + const isSingleToSingleOrCollection = isSourceSingle && isSingleOrCollection(targetType); + const isCollectionToCollection = isCollection(sourceType) && isCollection(targetType); + const isCollectionToSingleOrCollection = isCollection(sourceType) && isSingleOrCollection(targetType); + const isSingleOrCollectionToSingleOrCollection = isSingleOrCollection(sourceType) && isSingleOrCollection(targetType); + const doesCardinalityMatch = + isSingleToSingle || isCollectionToCollection || - isCollectionToCollectionOrScalar || - isCollectionOrScalarToCollectionOrScalar || - isScalarToCollectionOrScalar; + isCollectionToSingleOrCollection || + isSingleOrCollectionToSingleOrCollection || + isSingleToSingleOrCollection; const isIntToFloat = sourceType.name === 'IntegerField' && targetType.name === 'FloatField'; const isIntToString = sourceType.name === 'IntegerField' && targetType.name === 'StringField'; const isFloatToString = sourceType.name === 'FloatField' && targetType.name === 'StringField'; - const isSubTypeMatch = isPluralityMatch && (isIntToFloat || isIntToString || isFloatToString); + const isSubTypeMatch = doesCardinalityMatch && (isIntToFloat || isIntToString || isFloatToString); const isTargetAnyType = targetType.name === 'AnyField'; @@ -66,8 +65,8 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || - isAnythingToCollectionOrScalarOfSameBaseType || - isGenericCollectionToAnyCollectionOrCollectionOrScalar || + isAnythingToSingleOrCollectionOfSameBaseType || + isGenericCollectionToAnyCollectionOrSingleOrCollection || isCollectionToGenericCollection || isSubTypeMatch || isTargetAnyType diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts index cc12b45aa6..3d3aff3cd6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.test.ts @@ -4,6 +4,7 @@ import { UnsupportedPrimitiveTypeError, UnsupportedUnionError, } from 'features/nodes/types/error'; +import type { FieldType } from 'features/nodes/types/field'; import type { InvocationFieldSchema, OpenAPIV3_1SchemaOrRef } from 'features/nodes/types/openapi'; import { parseFieldType, refObjectToSchemaName } from 'features/nodes/util/schema/parseFieldType'; import { describe, expect, it } from 'vitest'; @@ -11,52 +12,52 @@ import { describe, expect, it } from 'vitest'; type ParseFieldTypeTestCase = { name: string; schema: OpenAPIV3_1SchemaOrRef | InvocationFieldSchema; - expected: { name: string; isCollection: boolean; isCollectionOrScalar: boolean }; + expected: FieldType; }; const primitiveTypes: ParseFieldTypeTestCase[] = [ { - name: 'Scalar IntegerField', + name: 'SINGLE IntegerField', schema: { type: 'integer' }, - expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'IntegerField', cardinality: 'SINGLE' }, }, { - name: 'Scalar FloatField', + name: 'SINGLE FloatField', schema: { type: 'number' }, - expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'FloatField', cardinality: 'SINGLE' }, }, { - name: 'Scalar StringField', + name: 'SINGLE StringField', schema: { type: 'string' }, - expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'StringField', cardinality: 'SINGLE' }, }, { - name: 'Scalar BooleanField', + name: 'SINGLE BooleanField', schema: { type: 'boolean' }, - expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'BooleanField', cardinality: 'SINGLE' }, }, { - name: 'Collection IntegerField', + name: 'COLLECTION IntegerField', schema: { items: { type: 'integer' }, type: 'array' }, - expected: { name: 'IntegerField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'IntegerField', cardinality: 'COLLECTION' }, }, { - name: 'Collection FloatField', + name: 'COLLECTION FloatField', schema: { items: { type: 'number' }, type: 'array' }, - expected: { name: 'FloatField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'FloatField', cardinality: 'COLLECTION' }, }, { - name: 'Collection StringField', + name: 'COLLECTION StringField', schema: { items: { type: 'string' }, type: 'array' }, - expected: { name: 'StringField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'StringField', cardinality: 'COLLECTION' }, }, { - name: 'Collection BooleanField', + name: 'COLLECTION BooleanField', schema: { items: { type: 'boolean' }, type: 'array' }, - expected: { name: 'BooleanField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'BooleanField', cardinality: 'COLLECTION' }, }, { - name: 'CollectionOrScalar IntegerField', + name: 'SINGLE_OR_COLLECTION IntegerField', schema: { anyOf: [ { @@ -70,10 +71,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'IntegerField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'IntegerField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar FloatField', + name: 'SINGLE_OR_COLLECTION FloatField', schema: { anyOf: [ { @@ -87,10 +88,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'FloatField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'FloatField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar StringField', + name: 'SINGLE_OR_COLLECTION StringField', schema: { anyOf: [ { @@ -104,10 +105,10 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'StringField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'StringField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'CollectionOrScalar BooleanField', + name: 'SINGLE_OR_COLLECTION BooleanField', schema: { anyOf: [ { @@ -121,13 +122,13 @@ const primitiveTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'BooleanField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'BooleanField', cardinality: 'SINGLE_OR_COLLECTION' }, }, ]; const complexTypes: ParseFieldTypeTestCase[] = [ { - name: 'Scalar ConditioningField', + name: 'SINGLE ConditioningField', schema: { allOf: [ { @@ -135,10 +136,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE' }, }, { - name: 'Nullable Scalar ConditioningField', + name: 'Nullable SINGLE ConditioningField', schema: { anyOf: [ { @@ -149,10 +150,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE' }, }, { - name: 'Collection ConditioningField', + name: 'COLLECTION ConditioningField', schema: { anyOf: [ { @@ -163,7 +164,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'COLLECTION' }, }, { name: 'Nullable Collection ConditioningField', @@ -180,10 +181,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: true, isCollectionOrScalar: false }, + expected: { name: 'ConditioningField', cardinality: 'COLLECTION' }, }, { - name: 'CollectionOrScalar ConditioningField', + name: 'SINGLE_OR_COLLECTION ConditioningField', schema: { anyOf: [ { @@ -197,10 +198,10 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' }, }, { - name: 'Nullable CollectionOrScalar ConditioningField', + name: 'Nullable SINGLE_OR_COLLECTION ConditioningField', schema: { anyOf: [ { @@ -217,7 +218,7 @@ const complexTypes: ParseFieldTypeTestCase[] = [ }, ], }, - expected: { name: 'ConditioningField', isCollection: false, isCollectionOrScalar: true }, + expected: { name: 'ConditioningField', cardinality: 'SINGLE_OR_COLLECTION' }, }, ]; @@ -228,14 +229,14 @@ const specialCases: ParseFieldTypeTestCase[] = [ type: 'string', enum: ['large', 'base', 'small'], }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'String EnumField with one value', schema: { const: 'Some Value', }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (SchedulerField)', @@ -244,7 +245,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'SchedulerField', }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (AnyField)', @@ -253,7 +254,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'AnyField', }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, { name: 'Explicit ui_type (CollectionField)', @@ -262,7 +263,7 @@ const specialCases: ParseFieldTypeTestCase[] = [ enum: ['ddim', 'ddpm', 'deis'], ui_type: 'CollectionField', }, - expected: { name: 'EnumField', isCollection: false, isCollectionOrScalar: false }, + expected: { name: 'EnumField', cardinality: 'SINGLE' }, }, ]; diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 6f6ecaa5bb..ea9bf5bce4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -48,8 +48,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType // Fields with a single const value are defined as `Literal["value"]` in the pydantic schema - it's actually an enum return { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } if (!schemaObject.type) { @@ -65,8 +64,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } } else if (schemaObject.anyOf) { @@ -89,8 +87,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } else if (isSchemaObject(filteredAnyOf[0])) { return parseFieldType(filteredAnyOf[0]); @@ -143,8 +140,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType if (firstType && firstType === secondType) { return { name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType, - isCollection: false, - isCollectionOrScalar: true, // <-- don't forget, CollectionOrScalar type! + cardinality: 'SINGLE_OR_COLLECTION', }; } @@ -158,8 +154,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } else if (schemaObject.enum) { return { name: 'EnumField', - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } else if (schemaObject.type) { if (schemaObject.type === 'array') { @@ -185,8 +180,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } return { name, - isCollection: true, // <-- don't forget, collection! - isCollectionOrScalar: false, + cardinality: 'COLLECTION', }; } @@ -197,8 +191,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } return { name, - isCollection: true, // <-- don't forget, collection! - isCollectionOrScalar: false, + cardinality: 'COLLECTION', }; } else if (!isArray(schemaObject.type)) { // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' @@ -213,8 +206,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } } @@ -225,8 +217,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } return { name, - isCollection: false, - isCollectionOrScalar: false, + cardinality: 'SINGLE', }; } throw new FieldParseError(t('nodes.unableToParseFieldType')); diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index f9b93382f9..3981b759db 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -100,11 +100,10 @@ export const parseSchema = ( return inputsAccumulator; } - const fieldTypeOverride = property.ui_type + const fieldTypeOverride: FieldType | null = property.ui_type ? { name: property.ui_type, - isCollection: isCollectionFieldType(property.ui_type), - isCollectionOrScalar: false, + cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE', } : null; @@ -178,11 +177,10 @@ export const parseSchema = ( return outputsAccumulator; } - const fieldTypeOverride = property.ui_type + const fieldTypeOverride: FieldType | null = property.ui_type ? { name: property.ui_type, - isCollection: isCollectionFieldType(property.ui_type), - isCollectionOrScalar: false, + cardinality: isCollectionFieldType(property.ui_type) ? 'COLLECTION' : 'SINGLE', } : null; From 9e55ef3d4bbd8da2c787f3cb5af9162580a53a2c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 22:56:35 +1000 Subject: [PATCH 070/207] fix(ui): workflow migration field type At some point, I made a mistake and imported the wrong types to some files for the old v1 and v2 workflow schema migration data. The relevant zod schemas and inferred types have been restored. This change doesn't alter runtime behaviour. Only type annotations. --- .../features/nodes/types/v1/fieldTypeMap.ts | 4 ++-- .../web/src/features/nodes/types/v2/field.ts | 22 +++++++++++++++++++ .../nodes/util/workflow/migrations.ts | 6 ++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index 79946cd8d5..f1d4e61300 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -1,4 +1,4 @@ -import type { FieldType, StatefulFieldType } from 'features/nodes/types/field'; +import type { StatefulFieldType, StatelessFieldType } from 'features/nodes/types/v2/field'; import type { FieldTypeV1 } from './workflowV1'; @@ -165,7 +165,7 @@ const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { * Thus, this object was manually edited to ensure it is correct. */ const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { - [key in FieldTypeV1]?: FieldType; + [key in FieldTypeV1]?: StatelessFieldType; } = { Any: { name: 'AnyField', isCollection: false, isCollectionOrScalar: false }, ClipField: { diff --git a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts index 1e464fa76d..4b680d1de3 100644 --- a/invokeai/frontend/web/src/features/nodes/types/v2/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v2/field.ts @@ -316,6 +316,7 @@ const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ const zStatelessFieldType = zFieldTypeBase.extend({ name: z.string().min(1), // stateless --> we accept the field's name as the type }); +export type StatelessFieldType = z.infer; const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ type: zStatelessFieldType, @@ -327,6 +328,27 @@ const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ // #endregion +const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; + /** * Here we define the main field unions: * - FieldType diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 32369b88c9..c7bcbf0953 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -1,12 +1,12 @@ import { deepClone } from 'common/util/deepClone'; import { $templates } from 'features/nodes/store/nodesSlice'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; -import type { FieldType } from 'features/nodes/types/field'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { zSemVer } from 'features/nodes/types/semver'; import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from 'features/nodes/types/v1/fieldTypeMap'; import type { WorkflowV1 } from 'features/nodes/types/v1/workflowV1'; import { zWorkflowV1 } from 'features/nodes/types/v1/workflowV1'; +import type { StatelessFieldType } from 'features/nodes/types/v2/field'; import type { WorkflowV2 } from 'features/nodes/types/v2/workflow'; import { zWorkflowV2 } from 'features/nodes/types/v2/workflow'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; @@ -43,14 +43,14 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { if (!newFieldType) { throw new WorkflowMigrationError(t('nodes.unknownFieldType', { type: input.type })); } - (input.type as unknown as FieldType) = newFieldType; + (input.type as unknown as StatelessFieldType) = newFieldType; }); forEach(node.data.outputs, (output) => { const newFieldType = FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type]; if (!newFieldType) { throw new WorkflowMigrationError(t('nodes.unknownFieldType', { type: output.type })); } - (output.type as unknown as FieldType) = newFieldType; + (output.type as unknown as StatelessFieldType) = newFieldType; }); // Add node pack const invocationTemplate = templates[node.data.type]; From e88b807a13ab6f664c56f97a10d0c97467e2a103 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 23:22:07 +1000 Subject: [PATCH 071/207] docs(ui): update field type docs & comments --- docs/contributing/frontend/WORKFLOWS.md | 21 +++++++++---------- .../store/util/validateConnectionTypes.ts | 10 ++++----- .../canvas/addControlNetToLinearGraph.ts | 2 +- .../graph/canvas/addIPAdapterToLinearGraph.ts | 2 +- .../canvas/addT2IAdapterToLinearGraph.ts | 2 +- .../nodes/util/schema/parseFieldType.ts | 2 +- 6 files changed, 19 insertions(+), 20 deletions(-) diff --git a/docs/contributing/frontend/WORKFLOWS.md b/docs/contributing/frontend/WORKFLOWS.md index e71d797b8a..533419e070 100644 --- a/docs/contributing/frontend/WORKFLOWS.md +++ b/docs/contributing/frontend/WORKFLOWS.md @@ -117,13 +117,13 @@ Stateless fields do not store their value in the node, so their field instances "Custom" fields will always be treated as stateless fields. -##### Collection and Scalar Fields +##### Single and Collection Fields -Field types have a name and two flags which may identify it as a **collection** or **collection or scalar** field. +Field types have a name and cardinality property which may identify it as a **SINGLE**, **COLLECTION** or **SINGLE_OR_COLLECTION** field. -If a field is annotated in python as a list, its field type is parsed and flagged as a **collection** type (e.g. `list[int]`). - -If it is annotated as a union of a type and list, the type will be flagged as a **collection or scalar** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). +- If a field is annotated in python as a singular value or class, its field type is parsed as a **SINGLE** type (e.g. `int`, `ImageField`, `str`). +- If a field is annotated in python as a list, its field type is parsed as a **COLLECTION** type (e.g. `list[int]`). +- If it is annotated as a union of a type and list, the type will be parsed as a **SINGLE_OR_COLLECTION** type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). ## Implementation @@ -173,8 +173,7 @@ Field types are represented as structured objects: ```ts type FieldType = { name: string; - isCollection: boolean; - isCollectionOrScalar: boolean; + cardinality: 'SINGLE' | 'COLLECTION' | 'SINGLE_OR_COLLECTION'; }; ``` @@ -186,7 +185,7 @@ There are 4 general cases for field type parsing. When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. -We create a field type name from this `type` string (e.g. `string` -> `StringField`). +We create a field type name from this `type` string (e.g. `string` -> `StringField`). The cardinality is `"SINGLE"`. ##### Complex Types @@ -200,13 +199,13 @@ We need to **dereference** the schema to pull these out. Dereferencing may requi When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. -We use the item type for field type name, adding `isCollection: true` to the field type. +We use the item type for field type name. The cardinality is `"COLLECTION"`. -##### Collection or Scalar Types +##### Single or Collection Types When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. -After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isCollectionOrScalar: true` to the field type. +After verifying that the union has two members (a type and list of the same type), we use the type for field type name, with cardinality `"SINGLE_OR_COLLECTION"`. ##### Optional Fields diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts index a71ff513aa..d5dee6dbaf 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnectionTypes.ts @@ -21,11 +21,11 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field /** * Connection types must be the same for a connection, with exceptions: - * - CollectionItem can connect to any non-Collection - * - Non-Collections can connect to CollectionItem - * - Anything (non-Collections, Collections, CollectionOrScalar) can connect to CollectionOrScalar of the same base type - * - Generic Collection can connect to any other Collection or CollectionOrScalar - * - Any Collection can connect to a Generic Collection + * - CollectionItem can connect to any non-COLLECTION (e.g. SINGLE or SINGLE_OR_COLLECTION) + * - SINGLE can connect to CollectionItem + * - Anything (SINGLE, COLLECTION, SINGLE_OR_COLLECTION) can connect to SINGLE_OR_COLLECTION of the same base type + * - Generic CollectionField can connect to any other COLLECTION or SINGLE_OR_COLLECTION + * - Any COLLECTION can connect to a Generic Collection */ const isCollectionItemToNonCollection = sourceType.name === 'CollectionItemField' && !isCollection(targetType); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts index 2feba262c2..110a20e5a7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addControlNetToLinearGraph.ts @@ -29,7 +29,7 @@ export const addControlNetToLinearGraph = async ( assert(activeTabName !== 'generation', 'Tried to use addControlNetToLinearGraph on generation tab'); if (controlNets.length) { - // Even though denoise_latents' control input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' control input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const controlNetIterateNode: Invocation<'collect'> = { id: CONTROL_NET_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts index e9d9bd4663..1f24463419 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addIPAdapterToLinearGraph.ts @@ -25,7 +25,7 @@ export const addIPAdapterToLinearGraph = async ( }); if (ipAdapters.length) { - // Even though denoise_latents' ip adapter input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' ip adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const ipAdapterCollectNode: Invocation<'collect'> = { id: IP_ADAPTER_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts index 7c51d9488f..72cf9ca0f8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/canvas/addT2IAdapterToLinearGraph.ts @@ -28,7 +28,7 @@ export const addT2IAdaptersToLinearGraph = async ( ); if (t2iAdapters.length) { - // Even though denoise_latents' t2i adapter input is collection or scalar, keep it simple and always use a collect + // Even though denoise_latents' t2i adapter input is SINGLE_OR_COLLECTION, keep it simple and always use a collect const t2iAdapterCollectNode: Invocation<'collect'> = { id: T2I_ADAPTER_COLLECT, type: 'collect', diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index ea9bf5bce4..18dcd8fb21 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -94,7 +94,7 @@ export const parseFieldType = (schemaObject: OpenAPIV3_1SchemaOrRef): FieldType } } /** - * Handle CollectionOrScalar inputs, eg string | string[]. In OpenAPI, this is: + * Handle SINGLE_OR_COLLECTION inputs, eg string | string[]. In OpenAPI, this is: * - an `anyOf` with two items * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` * - the other is a `SchemaObject` or `ReferenceObject` of type T From 1c29b3bd8573913deb915e4d5c67b86ec1daff8b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 19 May 2024 23:24:16 +1000 Subject: [PATCH 072/207] feat(ui): updated field type translations --- invokeai/frontend/web/public/locales/en.json | 5 +++-- .../web/src/features/nodes/hooks/usePrettyFieldType.ts | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 5dd411c544..1d41a1de63 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -780,8 +780,9 @@ "missingFieldTemplate": "Missing field template", "nodePack": "Node pack", "collection": "Collection", - "collectionFieldType": "{{name}} Collection", - "collectionOrScalarFieldType": "{{name}} Collection|Scalar", + "singleFieldType": "{{name}} (Single)", + "collectionFieldType": "{{name}} (Collection)", + "collectionOrScalarFieldType": "{{name}} (Single or Collection)", "colorCodeEdges": "Color-Code Edges", "colorCodeEdgesHelp": "Color-code edges according to their connected fields", "connectionWouldCreateCycle": "Connection would create a cycle", diff --git a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts index 2600eae078..7f531c3dba 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts @@ -16,7 +16,7 @@ export const useFieldTypeName = (fieldType?: FieldType): string => { if (isSingleOrCollection(fieldType)) { return t('nodes.collectionOrScalarFieldType', { name }); } - return name; + return t('nodes.singleFieldType', { name }); }, [fieldType, t]); return name; From 55535881473c3dd9aa9672ff0f575802f87f08f2 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 08:38:29 +1000 Subject: [PATCH 073/207] fix(ui): ensure invocation edges have a type --- .../web/src/features/nodes/store/nodesSlice.ts | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index c63734c871..e7c1877647 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,7 +1,6 @@ import type { PayloadAction, UnknownAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; -import { deepClone } from 'common/util/deepClone'; import { workflowLoaded } from 'features/nodes/store/actions'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { @@ -105,7 +104,8 @@ export const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); }, edgesChanged: (state, action: PayloadAction) => { - const changes = deepClone(action.payload); + const changes: EdgeChange[] = []; + // We may need to massage the edge changes or otherwise handle them action.payload.forEach((change) => { if (change.type === 'remove' || change.type === 'select') { const edge = state.edges.find((e) => e.id === change.id); @@ -124,6 +124,13 @@ export const nodesSlice = createSlice({ } } } + if (change.type === 'add') { + if (!change.item.type) { + // We must add the edge type! + change.item.type = 'default'; + } + } + changes.push(change); }); state.edges = applyEdgeChanges(changes, state.edges); }, From 620ee2875ef52f796bb7ac4eb3146f32d9e9fa8f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 20 May 2024 08:41:05 +1000 Subject: [PATCH 074/207] fix(ui): store `hidden` state of edges in workflows This prevents a minor visual bug where collapsed edges between collapsed nodes didn't display correctly on first load of a workflow. --- invokeai/frontend/web/src/features/nodes/types/workflow.ts | 1 + .../web/src/features/nodes/util/workflow/buildWorkflow.ts | 1 + 2 files changed, 2 insertions(+) diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index a424bf8d4b..9805edfaf2 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -47,6 +47,7 @@ const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ type: z.literal('default'), sourceHandle: z.string().trim().min(1), targetHandle: z.string().trim().min(1), + hidden: z.boolean().optional(), }); const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ type: z.literal('collapsed'), diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index b164dde90e..cec8b0a2b7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -66,6 +66,7 @@ export const buildWorkflowFast: BuildWorkflowFunction = ({ nodes, edges, workflo target: edge.target, sourceHandle: edge.sourceHandle, targetHandle: edge.targetHandle, + hidden: edge.hidden, }); } else if (edge.type === 'collapsed') { newWorkflow.edges.push({ From 32277193b6768dc1ad5b11b6ceaf7b17ddfc3dbd Mon Sep 17 00:00:00 2001 From: steffylo Date: Mon, 20 May 2024 15:49:18 +0800 Subject: [PATCH 075/207] fix(ui): retain denoise strength and opacity when changing image --- .../controlLayers/store/controlLayersSlice.ts | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts index 32e29918ae..dbd99c2450 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/controlLayersSlice.ts @@ -616,12 +616,24 @@ export const controlLayersSlice = createSlice({ iiLayerAdded: { reducer: (state, action: PayloadAction<{ layerId: string; imageDTO: ImageDTO | null }>) => { const { layerId, imageDTO } = action.payload; + + // Retain opacity and denoising strength of existing initial image layer if exists + let opacity = 1; + let denoisingStrength = 0.75; + const iiLayer = state.layers.find((l) => l.id === layerId); + if (iiLayer) { + assert(isInitialImageLayer(iiLayer)); + opacity = iiLayer.opacity; + denoisingStrength = iiLayer.denoisingStrength; + } + // Highlander! There can be only one! state.layers = state.layers.filter((l) => (isInitialImageLayer(l) ? false : true)); + const layer: InitialImageLayer = { id: layerId, type: 'initial_image_layer', - opacity: 1, + opacity, x: 0, y: 0, bbox: null, @@ -629,7 +641,7 @@ export const controlLayersSlice = createSlice({ isEnabled: true, image: imageDTO ? imageDTOToImageWithDims(imageDTO) : null, isSelected: true, - denoisingStrength: 0.75, + denoisingStrength, }; state.layers.push(layer); exclusivelySelectLayer(state, layer.id); From 66c9f4708d14682d98ef0dc326c08d6a249de07c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 21 May 2024 06:59:56 +1000 Subject: [PATCH 076/207] Update invokeai_version.py --- invokeai/version/invokeai_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py index aef46acb47..2e905e44da 100644 --- a/invokeai/version/invokeai_version.py +++ b/invokeai/version/invokeai_version.py @@ -1 +1 @@ -__version__ = "4.2.1" +__version__ = "4.2.2" From 1249d4a6e3a9237272aef59835035cc683dc6e22 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 21 May 2024 10:06:09 +1000 Subject: [PATCH 077/207] fix(ui): crash when using a notes node --- .../src/features/nodes/hooks/useNodeLabel.ts | 6 ++--- .../nodes/hooks/useNodeTemplateTitle.ts | 22 ++++++++++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index 31dcb9c466..56e77a39e8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -1,14 +1,14 @@ import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { selectNodeData } from 'features/nodes/store/selectors'; import { useMemo } from 'react'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( () => - createSelector(selectNodesSlice, (nodes) => { - return selectNodeData(nodes, nodeId)?.label ?? null; + createSelector(selectNodesSlice, (nodesSlice) => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + return node?.data.label; }), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index a63e0433aa..39ae617460 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -1,8 +1,24 @@ -import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; +import { useStore } from '@nanostores/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { useMemo } from 'react'; export const useNodeTemplateTitle = (nodeId: string): string | null => { - const template = useNodeTemplate(nodeId); - const title = useMemo(() => template.title, [template.title]); + const templates = useStore($templates); + const selector = useMemo( + () => + createSelector(selectNodesSlice, (nodesSlice) => { + const node = nodesSlice.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return null; + } + const template = templates[node.data.type]; + return template?.title ?? null; + }), + [nodeId, templates] + ); + const title = useAppSelector(selector); return title; }; From e75f98317f2333909cd9e180c211876711d9cccb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 21 May 2024 10:06:25 +1000 Subject: [PATCH 078/207] fix(ui): notes node text not selectable --- .../features/nodes/components/flow/nodes/Notes/NotesNode.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx index 966809cb0e..76666af396 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx @@ -48,7 +48,7 @@ const NotesNode = (props: NodeProps) => { gap={1} > -