From ace0eb366b429aabd2b05474f51ff6d67404f74e Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 27 Nov 2023 23:18:23 -0500 Subject: [PATCH 01/65] pin opencv-python to get required cv2.typing module --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 97449a700f..08858059fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ dependencies = [ "omegaconf", "onnx", "onnxruntime", - "opencv-python", + "opencv-python~=4.8.1.1", "pydantic~=2.5.0", "pydantic-settings~=2.0.3", "picklescan", From ff0a25bd9c800868656bf5a08b826aa68f75af48 Mon Sep 17 00:00:00 2001 From: skunkworxdark Date: Tue, 28 Nov 2023 12:07:29 +0000 Subject: [PATCH 02/65] Update communityNodes.md Added New Match Histogram node Updated XYGrid nodes and Prompt Tools nodes --- docs/nodes/communityNodes.md | 64 ++++++++++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 13 deletions(-) diff --git a/docs/nodes/communityNodes.md b/docs/nodes/communityNodes.md index 3879cdc3c3..6007b3338b 100644 --- a/docs/nodes/communityNodes.md +++ b/docs/nodes/communityNodes.md @@ -26,6 +26,7 @@ To use a community workflow, download the the `.json` node graph file and load i + [Image Picker](#image-picker) + [Load Video Frame](#load-video-frame) + [Make 3D](#make-3d) + + [Match Histogram](#match-histogram) + [Oobabooga](#oobabooga) + [Prompt Tools](#prompt-tools) + [Remote Image](#remote-image) @@ -208,6 +209,23 @@ This includes 15 Nodes: +-------------------------------- +### Match Histogram + +**Description:** An InvokeAI node to match a histogram from one image to another. This is a bit like the `color correct` node in the main InvokeAI but this works in the YCbCr colourspace and can handle images of different sizes. Also does not require a mask input. +- Option to only transfer luminance channel. +- Option to save output as grayscale + +A good use case for this node is to normalize the colors of an image that has been through the tiled scaling workflow of my XYGrid Nodes. + +See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md + +**Node Link:** https://github.com/skunkworxdark/match_histogram + +**Output Examples** + + + -------------------------------- ### Oobabooga @@ -237,22 +255,30 @@ This node works best with SDXL models, especially as the style can be described -------------------------------- ### Prompt Tools -**Description:** A set of InvokeAI nodes that add general prompt manipulation tools. These were written to accompany the PromptsFromFile node and other prompt generation nodes. +**Description:** A set of InvokeAI nodes that add general prompt (string) manipulation tools. Designed to accompany the `Prompts From File` node and other prompt generation nodes. + +1. `Prompt To File` - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option. +2. `PTFields Collect` - Converts image generation fields into a Json format string that can be passed to Prompt to file. +3. `PTFields Expand` - Takes Json string and converts it to individual generation parameters. This can be fed from the Prompts to file node. +4. `Prompt Strength` - Formats prompt with strength like the weighted format of compel +5. `Prompt Strength Combine` - Combines weighted prompts for .and()/.blend() +6. `CSV To Index String` - Gets a string from a CSV by index. Includes a Random index option + +The following Nodes are now included in v3.2 of Invoke and are nolonger in this set of tools.
+- `Prompt Join` -> `String Join` +- `Prompt Join Three` -> `String Join Three` +- `Prompt Replace` -> `String Replace` +- `Prompt Split Neg` -> `String Split Neg` -1. PromptJoin - Joins to prompts into one. -2. PromptReplace - performs a search and replace on a prompt. With the option of using regex. -3. PromptSplitNeg - splits a prompt into positive and negative using the old V2 method of [] for negative. -4. PromptToFile - saves a prompt or collection of prompts to a file. one per line. There is an append/overwrite option. -5. PTFieldsCollect - Converts image generation fields into a Json format string that can be passed to Prompt to file. -6. PTFieldsExpand - Takes Json string and converts it to individual generation parameters This can be fed from the Prompts to file node. -7. PromptJoinThree - Joins 3 prompt together. -8. PromptStrength - This take a string and float and outputs another string in the format of (string)strength like the weighted format of compel. -9. PromptStrengthCombine - This takes a collection of prompt strength strings and outputs a string in the .and() or .blend() format that can be fed into a proper prompt node. See full docs here: https://github.com/skunkworxdark/Prompt-tools-nodes/edit/main/README.md **Node Link:** https://github.com/skunkworxdark/Prompt-tools-nodes +**Workflow Examples** + + + -------------------------------- ### Remote Image @@ -339,15 +365,27 @@ Highlights/Midtones/Shadows (with LUT blur enabled): -------------------------------- ### XY Image to Grid and Images to Grids nodes -**Description:** Image to grid nodes and supporting tools. +**Description:** These nodes add the following to InvokeAI: +- Generate grids of images from multiple input images +- Create XY grid images with labels from parameters +- Split images into overlapping tiles for processing (for super-resolution workflows) +- Recombine image tiles into a single output image blending the seams -1. "Images To Grids" node - Takes a collection of images and creates a grid(s) of images. If there are more images than the size of a single grid then multiple grids will be created until it runs out of images. -2. "XYImage To Grid" node - Converts a collection of XYImages into a labeled Grid of images. The XYImages collection has to be built using the supporting nodes. See example node setups for more details. +The nodes include: +1. `Images To Grids` - Combine multiple images into a grid of images +2. `XYImage To Grid` - Take X & Y params and creates a labeled image grid. +3. `XYImage Tiles` - Super-resolution (embiggen) style tiled resizing +4. `Image Tot XYImages` - Takes an image and cuts it up into a number of columns and rows. +5. Multiple supporting nodes - Helper nodes for data wrangling and building `XYImage` collections See full docs here: https://github.com/skunkworxdark/XYGrid_nodes/edit/main/README.md **Node Link:** https://github.com/skunkworxdark/XYGrid_nodes +**Output Examples** + + + -------------------------------- ### Example Node Template From 4eca802cddcff6d98b78a9ba5d3cbf6190ce1a2e Mon Sep 17 00:00:00 2001 From: Mary Hipp Rogers Date: Tue, 28 Nov 2023 09:24:54 -0500 Subject: [PATCH 03/65] fix preselected image (#5185) * fix for new response shape * unused import --------- Co-authored-by: Mary Hipp --- .../web/src/features/parameters/hooks/usePreselectedImage.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts b/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts index 4ea4f93bac..b27bf3d572 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/usePreselectedImage.ts @@ -1,5 +1,4 @@ import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { CoreMetadata } from 'features/nodes/types/types'; import { t } from 'i18next'; import { useCallback, useEffect } from 'react'; import { useAppToaster } from '../../../app/components/Toaster'; @@ -51,7 +50,7 @@ export const usePreselectedImage = (selectedImage?: { const handleUseAllMetadata = useCallback(() => { if (selectedImageMetadata) { - recallAllParameters(selectedImageMetadata.metadata as CoreMetadata); + recallAllParameters(selectedImageMetadata); } // disabled because `recallAllParameters` changes the model, but its dep to prepare LoRAs has model as a dep. this introduces circular logic that causes infinite re-renders // eslint-disable-next-line react-hooks/exhaustive-deps From 0d524304811a146631c5313e19a227835f32c3fc Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 28 Nov 2023 01:27:04 +0100 Subject: [PATCH 04/65] move toast to the bottom right --- invokeai/frontend/web/src/theme/theme.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/theme/theme.ts b/invokeai/frontend/web/src/theme/theme.ts index ae38aefca0..d51fae5ab7 100644 --- a/invokeai/frontend/web/src/theme/theme.ts +++ b/invokeai/frontend/web/src/theme/theme.ts @@ -150,5 +150,5 @@ export const theme: ThemeOverride = { }; export const TOAST_OPTIONS: ToastProviderProps = { - defaultOptions: { isClosable: true }, + defaultOptions: { isClosable: true, position: 'bottom-right' }, }; From 86a74e929a377092a17bdc84d95d19771f81f062 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:32:35 +1100 Subject: [PATCH 05/65] feat(ui): add support for custom field types Node authors may now create their own arbitrary/custom field types. Any pydantic model is supported. Two notes: 1. Your field type's class name must be unique. Suggest prefixing fields with something related to the node pack as a kind of namespace. 2. Custom field types function as connection-only fields. For example, if your custom field has string attributes, you will not get a text input for that attribute when you give a node a field with your custom type. This is the same behaviour as other complex fields that don't have custom UIs in the workflow editor - like, say, a string collection. feat(ui): fix tooltips for custom types We need to hold onto the original type of the field so they don't all just show up as "Unknown". fix(ui): fix ts error with custom fields feat(ui): custom field types connection validation In the initial commit, a custom field's original type was added to the *field templates* only as `originalType`. Custom fields' `type` property was `"Custom"`*. This allowed for type safety throughout the UI logic. *Actually, it was `"Unknown"`, but I changed it to custom for clarity. Connection validation logic, however, uses the *field instance* of the node/field. Like the templates, *field instances* with custom types have their `type` set to `"Custom"`, but they didn't have an `originalType` property. As a result, all custom fields could be connected to all other custom fields. To resolve this, we need to add `originalType` to the *field instances*, then switch the validation logic to use this instead of `type`. This ended up needing a bit of fanagling: - If we make `originalType` a required property on field instances, existing workflows will break during connection validation, because they won't have this property. We'd need a new layer of logic to migrate the workflows, adding the new `originalType` property. While this layer is probably needed anyways, typing `originalType` as optional is much simpler. Workflow migration logic can come layer. (Technically, we could remove all references to field types from the workflow files, and let the templates hold all this information. This feels like a significant change and I'm reluctant to do it now.) - Because `originalType` is optional, anywhere we care about the type of a field, we need to use it over `type`. So there are a number of `field.originalType ?? field.type` expressions. This is a bit of a gotcha, we'll need to remember this in the future. - We use `Array.prototype.includes()` often in the workflow editor, e.g. `COLLECTION_TYPES.includes(type)`. In these cases, the const array is of type `FieldType[]`, and `type` is is `FieldType`. Because we now support custom types, the arg `type` is now widened from `FieldType` to `string`. This causes a TS error. This behaviour is somewhat controversial (see https://github.com/microsoft/TypeScript/issues/14520). These expressions are now rewritten as `COLLECTION_TYPES.some((t) => t === type)` to satisfy TS. It's logically equivalent. fix(ui): typo feat(ui): add CustomCollection and CustomPolymorphic field types feat(ui): add validation for CustomCollection & CustomPolymorphic types - Update connection validation for custom types - Use simple string parsing to determine if a field is a collection or polymorphic type. - No longer need to keep a list of collection and polymorphic types. - Added runtime checks in `baseinvocation.py` to ensure no fields are named in such a way that it could mess up the new parsing chore(ui): remove errant console.log fix(ui): rename 'nodes.currentConnectionFieldType' -> 'nodes.connectionStartFieldType' This was confusingly named and kept tripping me up. Renamed to be consistent with the `reactflow` `ConnectionStartParams` type. fix(ui): fix ts error feat(nodes): add runtime check for custom field names "Custom", "CustomCollection" and "CustomPolymorphic" are reserved field names. chore(ui): add TODO for revising field type names wip refactor fieldtype structured wip refactor field types wip refactor types wip refactor types fix node layout refactor field types chore: mypy organisation organisation organisation fix(nodes): fix field orig_required, field_kind and input statuses feat(nodes): remove broken implementation of default_factory on InputField Use of this could break connection validation due to the difference in node schemas required fields and invoke() required args. Removed entirely for now. It wasn't ever actually used by the system, because all graphs always had values provided for fields where default_factory was used. Also, pydantic is smart enough to not reuse the same object when specifying a default value - it clones the object first. So, the common pattern of `default_factory=list` is extraneous. It can just be `default=[]`. fix(nodes): fix InputField name validation workflow validation validation chore: ruff feat(nodes): fix up baseinvocation comments fix(ui): improve typing & logic of buildFieldInputTemplate improved error handling in parseFieldType fix: back compat for deprecated default_factory and UIType feat(nodes): do not show node packs loaded log if none loaded chore(ui): typegen --- invokeai/app/api_app.py | 22 +- invokeai/app/invocations/baseinvocation.py | 394 ++-- invokeai/app/invocations/collections.py | 6 +- invokeai/app/invocations/custom_nodes/init.py | 6 +- invokeai/app/invocations/infill.py | 6 +- invokeai/app/invocations/ip_adapter.py | 3 +- invokeai/app/invocations/latent.py | 5 +- invokeai/app/invocations/model.py | 2 - invokeai/app/invocations/noise.py | 6 +- invokeai/app/invocations/primitives.py | 20 +- invokeai/app/invocations/t2i_adapter.py | 3 +- invokeai/app/services/shared/graph.py | 10 +- invokeai/frontend/web/public/locales/en.json | 25 +- .../middleware/listenerMiddleware/index.ts | 4 +- .../listeners/controlNetImageProcessed.ts | 2 +- .../listeners/imageDeleted.ts | 9 +- .../listeners/modelSelected.ts | 4 +- .../listeners/modelsLoaded.ts | 12 +- .../listeners/receivedOpenAPISchema.ts | 1 + .../socketio/socketInvocationComplete.ts | 2 +- .../listeners/updateAllNodesRequested.ts | 48 +- .../listeners/workflowLoadRequested.ts | 105 + .../listeners/workflowLoaded.ts | 56 - .../src/common/hooks/useIsReadyToEnqueue.ts | 2 +- .../store/controlAdaptersSlice.ts | 12 +- .../features/controlAdapters/store/types.ts | 12 +- .../deleteImageModal/store/selectors.ts | 6 +- .../web/src/features/dnd/types/index.ts | 10 +- .../ImageMetadataActions.tsx | 18 +- .../web/src/features/lora/store/loraSlice.ts | 4 +- .../flow/AddNodePopover/AddNodePopover.tsx | 9 +- .../connectionLines/CustomConnectionLine.tsx | 11 +- .../flow/edges/util/getEdgeColor.ts | 12 + .../flow/edges/util/makeEdgeSelector.ts | 6 +- .../InvocationNodeCollapsedHandles.tsx | 2 +- .../Invocation/InvocationNodeInfoIcon.tsx | 6 +- .../InvocationNodeStatusIndicator.tsx | 21 +- .../Invocation/InvocationNodeWrapper.tsx | 2 +- .../flow/nodes/Invocation/NotesTextarea.tsx | 2 +- .../Invocation/fields/FieldContextMenu.tsx | 2 +- .../nodes/Invocation/fields/FieldHandle.tsx | 43 +- .../Invocation/fields/FieldTooltipContent.tsx | 22 +- .../nodes/Invocation/fields/InputField.tsx | 2 +- .../Invocation/fields/InputFieldRenderer.tsx | 259 ++- ...Field.tsx => BoardFieldInputComponent.tsx} | 14 +- ...eld.tsx => BooleanFieldInputComponent.tsx} | 18 +- ...Field.tsx => ColorFieldInputComponent.tsx} | 14 +- ...=> ControlNetModelFieldInputComponent.tsx} | 16 +- ...tField.tsx => EnumFieldInputComponent.tsx} | 14 +- ... => IPAdapterModelFieldInputComponent.tsx} | 16 +- ...Field.tsx => ImageFieldInputComponent.tsx} | 19 +- ...d.tsx => LoRAModelFieldInputComponent.tsx} | 16 +- ...d.tsx => MainModelFieldInputComponent.tsx} | 16 +- ...ield.tsx => NumberFieldInputComponent.tsx} | 32 +- ...sx => RefinerModelFieldInputComponent.tsx} | 16 +- ...x => SDXLMainModelFieldInputComponent.tsx} | 16 +- ...d.tsx => SchedulerFieldInputComponent.tsx} | 26 +- ...ield.tsx => StringFieldInputComponent.tsx} | 19 +- ...=> T2IAdapterModelFieldInputComponent.tsx} | 16 +- ...ld.tsx => VAEModelFieldInputComponent.tsx} | 16 +- .../nodes/Invocation/fields/inputs/types.ts | 13 + .../components/flow/nodes/Notes/NotesNode.tsx | 2 +- .../flow/nodes/common/NodeWrapper.tsx | 5 +- .../TopCenterPanel/LoadWorkflowButton.tsx | 2 +- .../panels/TopRightPanel/FieldTypeLegend.tsx | 31 - .../panels/TopRightPanel/TopRightPanel.tsx | 7 - .../inspector/InspectorDetailsTab.tsx | 39 +- .../inspector/InspectorOutputsTab.tsx | 2 +- .../hooks/useAnyOrDirectInputFieldNames.ts | 13 +- .../features/nodes/hooks/useBuildNodeData.ts | 26 +- .../hooks/useConnectionInputFieldNames.ts | 14 +- .../nodes/hooks/useConnectionState.ts | 2 +- .../nodes/hooks/useDoNodeVersionsMatch.ts | 2 +- .../nodes/hooks/useDoesInputHaveValue.ts | 2 +- .../features/nodes/hooks/useEmbedWorkflow.ts | 2 +- .../src/features/nodes/hooks/useFieldData.ts | 4 +- .../features/nodes/hooks/useFieldInputKind.ts | 2 +- .../src/features/nodes/hooks/useFieldLabel.ts | 2 +- .../features/nodes/hooks/useFieldTemplate.ts | 2 +- .../nodes/hooks/useFieldTemplateTitle.ts | 2 +- .../features/nodes/hooks/useFieldType.ts.ts | 5 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 8 +- .../features/nodes/hooks/useHasImageOutput.ts | 7 +- .../features/nodes/hooks/useIsIntermediate.ts | 2 +- .../nodes/hooks/useIsValidConnection.ts | 13 +- .../nodes/hooks/useLoadWorkflowFromFile.tsx | 48 +- .../src/features/nodes/hooks/useNodeLabel.ts | 2 +- .../nodes/hooks/useNodeNeedsUpdate.ts | 35 + .../nodes/hooks/useNodeTemplateByType.ts | 8 +- .../nodes/hooks/useNodeTemplateTitle.ts | 2 +- .../features/nodes/hooks/useNodeVersion.ts | 119 -- .../nodes/hooks/useOutputFieldNames.ts | 2 +- .../nodes/hooks/usePrettyFieldType.ts | 23 + .../src/features/nodes/hooks/useUseCache.ts | 2 +- .../features/nodes/hooks/useWithWorkflow.ts | 2 +- .../web/src/features/nodes/store/actions.ts | 3 +- .../nodes/store/nodesPersistDenylist.ts | 2 +- .../src/features/nodes/store/nodesSlice.ts | 204 +- .../web/src/features/nodes/store/types.ts | 18 +- .../nodes/store/util/buildNodeData.ts | 111 +- .../store/util/findConnectionToValidHandle.ts | 17 +- .../util/makeIsConnectionValidSelector.ts | 18 +- .../features/nodes/store/util/nodeUpdate.ts | 68 + .../util/validateSourceAndTargetTypes.ts | 72 +- .../web/src/features/nodes/types/common.ts | 216 ++ .../web/src/features/nodes/types/constants.ts | 463 +---- .../web/src/features/nodes/types/error.ts | 59 + .../web/src/features/nodes/types/field.ts | 1114 +++++++++++ .../src/features/nodes/types/invocation.ts | 108 + .../web/src/features/nodes/types/metadata.ts | 81 + .../nodes/types/migration/migrations.ts | 69 + .../nodes/types/migration/v1/fieldTypeMap.ts | 270 +++ .../nodes/types/migration/v1/workflowV1.ts | 711 +++++++ .../web/src/features/nodes/types/openapi.ts | 108 + .../web/src/features/nodes/types/semver.ts | 23 + .../web/src/features/nodes/types/types.ts | 1742 ----------------- .../web/src/features/nodes/types/workflow.ts | 91 + .../nodes/util/buildFieldInputInstance.ts | 42 + .../nodes/util/buildFieldInputTemplate.ts | 376 ++++ .../src/features/nodes/util/buildWorkflow.ts | 6 +- .../nodes/util/fieldTemplateBuilders.ts | 1210 ------------ .../features/nodes/util/fieldValueBuilders.ts | 85 - .../nodes/util/getSortedFilteredFieldNames.ts | 4 +- .../addControlNetToLinearGraph.ts | 2 +- .../nodes/util/graphBuilders/addHrfToGraph.ts | 2 +- .../addIPAdapterToLinearGraph.ts | 2 +- .../graphBuilders/addLinearUIOutputNode.ts | 3 +- .../util/graphBuilders/addLoRAsToGraph.ts | 2 +- .../graphBuilders/addNSFWCheckerToGraph.ts | 2 +- .../util/graphBuilders/addSDXLLoRAstoGraph.ts | 5 +- .../graphBuilders/addSDXLRefinerToGraph.ts | 2 +- .../graphBuilders/addSeamlessToLinearGraph.ts | 5 +- .../addT2IAdapterToLinearGraph.ts | 3 +- .../nodes/util/graphBuilders/addVAEToGraph.ts | 2 +- .../graphBuilders/addWatermarkerToGraph.ts | 2 +- .../graphBuilders/buildAdHocUpscaleGraph.ts | 2 +- .../util/graphBuilders/buildCanvasGraph.ts | 3 +- .../buildCanvasImageToImageGraph.ts | 9 +- .../graphBuilders/buildCanvasInpaintGraph.ts | 4 +- .../graphBuilders/buildCanvasOutpaintGraph.ts | 4 +- .../buildCanvasSDXLImageToImageGraph.ts | 11 +- .../buildCanvasSDXLInpaintGraph.ts | 4 +- .../buildCanvasSDXLOutpaintGraph.ts | 4 +- .../buildCanvasSDXLTextToImageGraph.ts | 4 +- .../buildCanvasTextToImageGraph.ts | 4 +- .../graphBuilders/buildLinearBatchConfig.ts | 3 +- .../buildLinearImageToImageGraph.ts | 4 +- .../buildLinearSDXLImageToImageGraph.ts | 4 +- .../buildLinearSDXLTextToImageGraph.ts | 6 +- .../buildLinearTextToImageGraph.ts | 6 +- .../util/graphBuilders/buildNodesGraph.ts | 10 +- .../nodes/util/graphBuilders/metadata.ts | 3 +- .../src/features/nodes/util/parseFieldType.ts | 233 +++ .../src/features/nodes/util/parseSchema.ts | 193 +- .../features/nodes/util/validateWorkflow.ts | 178 +- .../Parameters/Advanced/ParamClipSkip.tsx | 10 +- .../ParamCanvasCoherenceMode.tsx | 4 +- .../Parameters/Core/ParamScheduler.tsx | 10 +- .../Parameters/HighResFix/ParamHrfMethod.tsx | 4 +- .../Parameters/VAEModel/ParamVAEPrecision.tsx | 4 +- .../useCoreParametersCollapseLabel.ts | 0 .../parameters/hooks/useRecallParameters.ts | 154 +- .../parameters/store/generationSlice.ts | 93 +- .../features/parameters/types/constants.ts | 54 +- .../parameters/types/parameterSchemas.ts | 725 +++---- .../util/modelIdToControlNetModelParam.ts | 4 +- .../util/modelIdToIPAdapterModelParams.ts | 4 +- .../util/modelIdToLoRAModelParam.ts | 9 +- .../util/modelIdToMainModelParam.ts | 9 +- .../util/modelIdToSDXLRefinerModelParam.ts | 8 +- .../util/modelIdToT2IAdapterModelParam.ts | 4 +- .../parameters/util/modelIdToVAEModelParam.ts | 9 +- .../SDXLRefiner/ParamSDXLRefinerScheduler.tsx | 10 +- .../web/src/features/sdxl/store/sdxlSlice.ts | 20 +- .../SettingsModal/SettingsSchedulers.tsx | 8 +- .../ImageToImageTabCoreParameters.tsx | 2 +- .../TextToImageTabCoreParameters.tsx | 2 +- .../UnifiedCanvasCoreParameters.tsx | 2 +- .../web/src/features/ui/store/uiSlice.ts | 4 +- .../web/src/features/ui/store/uiTypes.ts | 4 +- .../web/src/services/api/endpoints/images.ts | 2 +- .../src/services/api/endpoints/workflows.ts | 6 +- .../frontend/web/src/services/api/guards.ts | 67 - .../frontend/web/src/services/api/schema.d.ts | 210 +- .../frontend/web/src/services/api/types.ts | 7 +- .../frontend/web/src/services/events/types.ts | 6 - 186 files changed, 5713 insertions(+), 5704 deletions(-) create mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{BoardInputField.tsx => BoardFieldInputComponent.tsx} (82%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{BooleanInputField.tsx => BooleanFieldInputComponent.tsx} (65%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{ColorInputField.tsx => ColorFieldInputComponent.tsx} (70%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{ControlNetModelInputField.tsx => ControlNetModelFieldInputComponent.tsx} (87%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{EnumInputField.tsx => EnumFieldInputComponent.tsx} (77%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{IPAdapterModelInputField.tsx => IPAdapterModelFieldInputComponent.tsx} (87%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{ImageInputField.tsx => ImageFieldInputComponent.tsx} (87%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{LoRAModelInputField.tsx => LoRAModelFieldInputComponent.tsx} (91%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{MainModelInputField.tsx => MainModelFieldInputComponent.tsx} (93%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{NumberInputField.tsx => NumberFieldInputComponent.tsx} (71%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{RefinerModelInputField.tsx => RefinerModelFieldInputComponent.tsx} (90%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{SDXLMainModelInputField.tsx => SDXLMainModelFieldInputComponent.tsx} (92%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{SchedulerInputField.tsx => SchedulerFieldInputComponent.tsx} (72%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{StringInputField.tsx => StringFieldInputComponent.tsx} (70%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{T2IAdapterModelInputField.tsx => T2IAdapterModelFieldInputComponent.tsx} (87%) rename invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/{VaeModelInputField.tsx => VAEModelFieldInputComponent.tsx} (89%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/FieldTypeLegend.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/common.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/error.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/field.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/invocation.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/metadata.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/openapi.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/semver.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/types/types.ts create mode 100644 invokeai/frontend/web/src/features/nodes/types/workflow.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts delete mode 100644 invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts rename invokeai/frontend/web/src/features/parameters/{util => hooks}/useCoreParametersCollapseLabel.ts (100%) delete mode 100644 invokeai/frontend/web/src/services/api/guards.ts diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 44471eab3c..79c7740485 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -1,11 +1,8 @@ -import sys -from typing import Any - -from fastapi.responses import HTMLResponse - # parse_args() must be called before any other imports. if it is not called first, consumers of the config # which are imported/used before parse_args() is called will get the default config values instead of the # values from the command line or config file. +import sys + from invokeai.version.invokeai_version import __version__ from .services.config import InvokeAIAppConfig @@ -22,6 +19,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c import socket from inspect import signature from pathlib import Path + from typing import Any import uvicorn from fastapi import FastAPI @@ -29,7 +27,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html from fastapi.openapi.utils import get_openapi - from fastapi.responses import FileResponse + from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware @@ -58,9 +56,9 @@ if True: # hack to make flake8 happy with imports coming after setting up the c from .api.sockets import SocketIO from .invocations.baseinvocation import ( BaseInvocation, + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, UIConfigBase, - _InputField, - _OutputField, ) if is_mps_available(): @@ -157,7 +155,11 @@ def custom_openapi() -> dict[str, Any]: # Add Node Editor UI helper schemas ui_config_schemas = models_json_schema( - [(UIConfigBase, "serialization"), (_InputField, "serialization"), (_OutputField, "serialization")], + [ + (UIConfigBase, "serialization"), + (InputFieldJSONSchemaExtra, "serialization"), + (OutputFieldJSONSchemaExtra, "serialization"), + ], ref_template="#/components/schemas/{model}", ) for schema_key, ui_config_schema in ui_config_schemas[1]["$defs"].items(): @@ -165,7 +167,7 @@ def custom_openapi() -> dict[str, Any]: # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: - invoker_name = invoker.__name__ + invoker_name = invoker.__name__ # type: ignore [attr-defined] # this is a valid attribute output_type = signature(obj=invoker.invoke).return_annotation output_type_title = output_type_titles[output_type.__name__] invoker_schema = openapi_schema["components"]["schemas"][f"{invoker_name}"] diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 1b3e535d34..cddbd071de 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -17,11 +17,15 @@ from pydantic_core import PydanticUndefined from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.shared.fields import FieldDescriptions +from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string +from invokeai.backend.util.logging import InvokeAILogger if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +logger = InvokeAILogger.get_logger() + class InvalidVersionError(ValueError): pass @@ -31,7 +35,7 @@ class InvalidFieldError(TypeError): pass -class Input(str, Enum): +class Input(str, Enum, metaclass=MetaEnum): """ The type of input a field accepts. - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ @@ -45,86 +49,120 @@ class Input(str, Enum): Any = "any" -class UIType(str, Enum): +class FieldKind(str, Enum, metaclass=MetaEnum): """ - Type hints for the UI. - If a field should be provided a data type that does not exactly match the python type of the field, \ - use this to provide the type that should be used instead. See the node development docs for detail \ - on adding a new field type, which involves client-side changes. + The kind of field. + - `Input`: An input field on a node. + - `Output`: An output field on a node. + - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + allowing "metadata" for that field. + - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + but which are used to store information about the node. For example, the `id` and `type` fields are node + attributes. + + The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + startup, and when generating the OpenAPI schema for the workflow editor. """ - # region Primitives - Boolean = "boolean" - Color = "ColorField" - Conditioning = "ConditioningField" - Control = "ControlField" - Float = "float" - Image = "ImageField" - Integer = "integer" - Latents = "LatentsField" - String = "string" - # endregion + Input = "input" + Output = "output" + Internal = "internal" + NodeAttribute = "node_attribute" - # region Collection Primitives - BooleanCollection = "BooleanCollection" - ColorCollection = "ColorCollection" - ConditioningCollection = "ConditioningCollection" - ControlCollection = "ControlCollection" - FloatCollection = "FloatCollection" - ImageCollection = "ImageCollection" - IntegerCollection = "IntegerCollection" - LatentsCollection = "LatentsCollection" - StringCollection = "StringCollection" - # endregion - # region Polymorphic Primitives - BooleanPolymorphic = "BooleanPolymorphic" - ColorPolymorphic = "ColorPolymorphic" - ConditioningPolymorphic = "ConditioningPolymorphic" - ControlPolymorphic = "ControlPolymorphic" - FloatPolymorphic = "FloatPolymorphic" - ImagePolymorphic = "ImagePolymorphic" - IntegerPolymorphic = "IntegerPolymorphic" - LatentsPolymorphic = "LatentsPolymorphic" - StringPolymorphic = "StringPolymorphic" - # endregion +class UIType(str, Enum, metaclass=MetaEnum): + """ + Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. - # region Models - MainModel = "MainModelField" + - Model Fields + The most common node-author-facing use will be for model fields. Internally, there is no difference + between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + the field is an SDXL main model field. + + - Any Field + We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + + - Scheduler Field + Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + + - Internal Fields + Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + should not be used by node authors. + """ + + # region Model Field Types SDXLMainModel = "SDXLMainModelField" SDXLRefinerModel = "SDXLRefinerModelField" ONNXModel = "ONNXModelField" - VaeModel = "VaeModelField" + VaeModel = "VAEModelField" LoRAModel = "LoRAModelField" ControlNetModel = "ControlNetModelField" IPAdapterModel = "IPAdapterModelField" - UNet = "UNetField" - Vae = "VaeField" - CLIP = "ClipField" # endregion - # region Iterate/Collect - Collection = "Collection" - CollectionItem = "CollectionItem" + # region Misc Field Types + Scheduler = "SchedulerField" + Any = "AnyField" # endregion - # region Misc - Enum = "enum" - Scheduler = "Scheduler" - WorkflowField = "WorkflowField" - IsIntermediate = "IsIntermediate" - BoardField = "BoardField" - Any = "Any" - MetadataItem = "MetadataItem" - MetadataItemCollection = "MetadataItemCollection" - MetadataItemPolymorphic = "MetadataItemPolymorphic" - MetadataDict = "MetadataDict" + # region Internal Field Types + _Collection = "CollectionField" + _CollectionItem = "CollectionItemField" + # endregion + + # region DEPRECATED + Boolean = "DEPRECATED_Boolean" + Color = "DEPRECATED_Color" + Conditioning = "DEPRECATED_Conditioning" + Control = "DEPRECATED_Control" + Float = "DEPRECATED_Float" + Image = "DEPRECATED_Image" + Integer = "DEPRECATED_Integer" + Latents = "DEPRECATED_Latents" + String = "DEPRECATED_String" + BooleanCollection = "DEPRECATED_BooleanCollection" + ColorCollection = "DEPRECATED_ColorCollection" + ConditioningCollection = "DEPRECATED_ConditioningCollection" + ControlCollection = "DEPRECATED_ControlCollection" + FloatCollection = "DEPRECATED_FloatCollection" + ImageCollection = "DEPRECATED_ImageCollection" + IntegerCollection = "DEPRECATED_IntegerCollection" + LatentsCollection = "DEPRECATED_LatentsCollection" + StringCollection = "DEPRECATED_StringCollection" + BooleanPolymorphic = "DEPRECATED_BooleanPolymorphic" + ColorPolymorphic = "DEPRECATED_ColorPolymorphic" + ConditioningPolymorphic = "DEPRECATED_ConditioningPolymorphic" + ControlPolymorphic = "DEPRECATED_ControlPolymorphic" + FloatPolymorphic = "DEPRECATED_FloatPolymorphic" + ImagePolymorphic = "DEPRECATED_ImagePolymorphic" + IntegerPolymorphic = "DEPRECATED_IntegerPolymorphic" + LatentsPolymorphic = "DEPRECATED_LatentsPolymorphic" + StringPolymorphic = "DEPRECATED_StringPolymorphic" + MainModel = "DEPRECATED_MainModel" + UNet = "DEPRECATED_UNet" + Vae = "DEPRECATED_Vae" + CLIP = "DEPRECATED_CLIP" + Collection = "DEPRECATED_Collection" + CollectionItem = "DEPRECATED_CollectionItem" + Enum = "DEPRECATED_Enum" + WorkflowField = "DEPRECATED_WorkflowField" + IsIntermediate = "DEPRECATED_IsIntermediate" + BoardField = "DEPRECATED_BoardField" + MetadataItem = "DEPRECATED_MetadataItem" + MetadataItemCollection = "DEPRECATED_MetadataItemCollection" + MetadataItemPolymorphic = "DEPRECATED_MetadataItemPolymorphic" + MetadataDict = "DEPRECATED_MetadataDict" # endregion -class UIComponent(str, Enum): +class UIComponent(str, Enum, metaclass=MetaEnum): """ - The type of UI component to use for a field, used to override the default components, which are \ + The type of UI component to use for a field, used to override the default components, which are inferred from the field type. """ @@ -133,7 +171,7 @@ class UIComponent(str, Enum): Slider = "slider" -class _InputField(BaseModel): +class InputFieldJSONSchemaExtra(BaseModel): """ *DO NOT USE* This helper class is used to tell the client about our custom field attributes via OpenAPI @@ -142,12 +180,15 @@ class _InputField(BaseModel): """ input: Input - ui_hidden: bool - ui_type: Optional[UIType] - ui_component: Optional[UIComponent] - ui_order: Optional[int] - ui_choice_labels: Optional[dict[str, str]] - item_default: Optional[Any] + orig_required: bool + field_kind: FieldKind + default: Optional[Any] = None + orig_default: Optional[Any] = None + ui_hidden: bool = False + ui_type: Optional[UIType] = None + ui_component: Optional[UIComponent] = None + ui_order: Optional[int] = None + ui_choice_labels: Optional[dict[str, str]] = None model_config = ConfigDict( validate_assignment=True, @@ -155,7 +196,7 @@ class _InputField(BaseModel): ) -class _OutputField(BaseModel): +class OutputFieldJSONSchemaExtra(BaseModel): """ *DO NOT USE* This helper class is used to tell the client about our custom field attributes via OpenAPI @@ -163,6 +204,7 @@ class _OutputField(BaseModel): purpose in the backend. """ + field_kind: FieldKind ui_hidden: bool ui_type: Optional[UIType] ui_order: Optional[int] @@ -180,6 +222,7 @@ def get_type(klass: BaseModel) -> str: def InputField( # copied from pydantic's Field + # TODO: Can we support default_factory? default: Any = _Unset, default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, @@ -203,12 +246,11 @@ def InputField( ui_hidden: bool = False, ui_order: Optional[int] = None, ui_choice_labels: Optional[dict[str, str]] = None, - item_default: Optional[Any] = None, ) -> Any: """ Creates an input field for an invocation. - This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.Field) \ that adds a few extra parameters to support graph execution and the node editor UI. :param Input input: [Input.Any] The kind of input this field requires. \ @@ -228,28 +270,59 @@ def InputField( For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ For this case, you could provide `UIComponent.Textarea`. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. - : param bool item_default: [None] Specifies the default item value, if this is a collection input. \ - Ignored for non-collection fields. + :param dict[str, str] ui_choice_labels: [None] Specifies the labels to use for the choices in an enum field. """ - json_schema_extra_: dict[str, Any] = { - "input": input, - "ui_type": ui_type, - "ui_component": ui_component, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "item_default": item_default, - "ui_choice_labels": ui_choice_labels, - "_field_kind": "input", - } + json_schema_extra_ = InputFieldJSONSchemaExtra( + input=input, + ui_type=ui_type, + ui_component=ui_component, + ui_hidden=ui_hidden, + ui_order=ui_order, + ui_choice_labels=ui_choice_labels, + field_kind=FieldKind.Input, + orig_required=True, + ) + """ + There is a conflict between the typing of invocation definitions and the typing of an invocation's + `invoke()` function. + + On instantiation of a node, the invocation definition is used to create the python class. At this time, + any number of fields may be optional, because they may be provided by connections. + + On calling of `invoke()`, however, those fields may be required. + + For example, consider an ResizeImageInvocation with an `image: ImageField` field. + + `image` is required during the call to `invoke()`, but when the python class is instantiated, + the field may not be present. This is fine, because that image field will be provided by a + connection from an ancestor node, which outputs an image. + + This means we want to type the `image` field as optional for the node class definition, but required + for the `invoke()` function. + + If we use `typing.Optional` in the node class definition, the field will be typed as optional in the + `invoke()` method, and we'll have to do a lot of runtime checks to ensure the field is present - or + any static type analysis tools will complain. + + To get around this, in node class definitions, we type all fields correctly for the `invoke()` function, + but secretly make them optional in `InputField()`. We also store the original required bool and/or default + value. When we call `invoke()`, we use this stored information to do an additional check on the class. + """ + + if default_factory is not _Unset and default_factory is not None: + default = default_factory() + del default_factory + logger.warn('"default_factory" is not supported, calling it now to set "default"') + + # These are the args we may wish pass to the pydantic `Field()` function field_args = { "default": default, - "default_factory": default_factory, "title": title, "description": description, "pattern": pattern, @@ -266,70 +339,34 @@ def InputField( "max_length": max_length, } - """ - Invocation definitions have their fields typed correctly for their `invoke()` functions. - This typing is often more specific than the actual invocation definition requires, because - fields may have values provided only by connections. - - For example, consider an ResizeImageInvocation with an `image: ImageField` field. - - `image` is required during the call to `invoke()`, but when the python class is instantiated, - the field may not be present. This is fine, because that image field will be provided by a - an ancestor node that outputs the image. - - So we'd like to type that `image` field as `Optional[ImageField]`. If we do that, however, then - we need to handle a lot of extra logic in the `invoke()` function to check if the field has a - value or not. This is very tedious. - - Ideally, the invocation definition would be able to specify that the field is required during - invocation, but optional during instantiation. So the field would be typed as `image: ImageField`, - but when calling the `invoke()` function, we raise an error if the field is not present. - - To do this, we need to do a bit of fanagling to make the pydantic field optional, and then do - extra validation when calling `invoke()`. - - There is some additional logic here to cleaning create the pydantic field via the wrapper. - """ - - # Filter out field args not provided + # We only want to pass the args that were provided, otherwise the `Field()`` function won't work as expected provided_args = {k: v for (k, v) in field_args.items() if v is not PydanticUndefined} - if (default is not PydanticUndefined) and (default_factory is not PydanticUndefined): - raise ValueError("Cannot specify both default and default_factory") + # Because we are manually making fields optional, we need to store the original required bool for reference later + json_schema_extra_.orig_required = default is PydanticUndefined - # because we are manually making fields optional, we need to store the original required bool for reference later - if default is PydanticUndefined and default_factory is PydanticUndefined: - json_schema_extra_.update({"orig_required": True}) - else: - json_schema_extra_.update({"orig_required": False}) - - # make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one - if (input is Input.Any or input is Input.Connection) and default_factory is PydanticUndefined: + # Make Input.Any and Input.Connection fields optional, providing None as a default if the field doesn't already have one + if input is Input.Any or input is Input.Connection: default_ = None if default is PydanticUndefined else default provided_args.update({"default": default_}) if default is not PydanticUndefined: - # before invoking, we'll grab the original default value and set it on the field if the field wasn't provided a value - json_schema_extra_.update({"default": default}) - json_schema_extra_.update({"orig_default": default}) - elif default is not PydanticUndefined and default_factory is PydanticUndefined: + # Before invoking, we'll check for the original default value and set it on the field if the field has no value + json_schema_extra_.default = default + json_schema_extra_.orig_default = default + elif default is not PydanticUndefined: default_ = default provided_args.update({"default": default_}) - json_schema_extra_.update({"orig_default": default_}) - elif default_factory is not PydanticUndefined: - provided_args.update({"default_factory": default_factory}) - # TODO: cannot serialize default_factory... - # json_schema_extra_.update(dict(orig_default_factory=default_factory)) + json_schema_extra_.orig_default = default_ return Field( **provided_args, - json_schema_extra=json_schema_extra_, + json_schema_extra=json_schema_extra_.model_dump(exclude_none=True), ) def OutputField( # copied from pydantic's Field default: Any = _Unset, - default_factory: Callable[[], Any] | None = _Unset, title: str | None = _Unset, description: str | None = _Unset, pattern: str | None = _Unset, @@ -362,13 +399,12 @@ def OutputField( `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ `UIType.SDXLMainModelField` to indicate that the field is an SDXL main model field. - : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + :param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ - : param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ + :param int ui_order: [None] Specifies the order in which this field should be rendered in the UI. \ """ return Field( default=default, - default_factory=default_factory, title=title, description=description, pattern=pattern, @@ -383,12 +419,12 @@ def OutputField( decimal_places=decimal_places, min_length=min_length, max_length=max_length, - json_schema_extra={ - "ui_type": ui_type, - "ui_hidden": ui_hidden, - "ui_order": ui_order, - "_field_kind": "output", - }, + json_schema_extra=OutputFieldJSONSchemaExtra( + ui_type=ui_type, + ui_hidden=ui_hidden, + ui_order=ui_order, + field_kind=FieldKind.Output, + ).model_dump(exclude_none=True), ) @@ -538,7 +574,7 @@ class BaseInvocation(ABC, BaseModel): return signature(cls.invoke).return_annotation @staticmethod - def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: # Add the various UI-facing attributes to the schema. These are used to build the invocation templates. uiconfig = getattr(model_class, "UIConfig", None) if uiconfig and hasattr(uiconfig, "title"): @@ -604,15 +640,17 @@ class BaseInvocation(ABC, BaseModel): id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", - json_schema_extra={"_field_kind": "internal"}, + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) is_intermediate: bool = Field( default=False, description="Whether or not this is an intermediate invocation.", - json_schema_extra={"ui_type": UIType.IsIntermediate, "_field_kind": "internal"}, + json_schema_extra={"ui_type": "IsIntermediate", "field_kind": FieldKind.NodeAttribute}, ) use_cache: bool = Field( - default=True, description="Whether or not to use the cache", json_schema_extra={"_field_kind": "internal"} + default=True, + description="Whether or not to use the cache", + json_schema_extra={"field_kind": FieldKind.NodeAttribute}, ) UIConfig: ClassVar[Type[UIConfigBase]] @@ -629,12 +667,15 @@ class BaseInvocation(ABC, BaseModel): TBaseInvocation = TypeVar("TBaseInvocation", bound=BaseInvocation) -RESERVED_INPUT_FIELD_NAMES = { +RESERVED_NODE_ATTRIBUTE_FIELD_NAMES = { "id", "is_intermediate", "use_cache", "type", "workflow", +} + +RESERVED_INPUT_FIELD_NAMES = { "metadata", } @@ -653,39 +694,56 @@ def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None """ Validates the fields of an invocation or invocation output: - must not override any pydantic reserved fields + - must not end with "Collection" or "Polymorphic" as these are reserved for internal use - must be created via `InputField`, `OutputField`, or be an internal field defined in this file """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved by pydantic)') - field_kind = ( - # _field_kind is defined via InputField(), OutputField() or by one of the internal fields defined in this file - field.json_schema_extra.get("_field_kind", None) if field.json_schema_extra else None - ) + if not field.annotation: + raise InvalidFieldError(f'Invalid field type "{name}" on "{model_type}" (missing annotation)') + + if not isinstance(field.json_schema_extra, dict): + raise InvalidFieldError( + f'Invalid field definition for "{name}" on "{model_type}" (missing json_schema_extra dict)' + ) + + field_kind = field.json_schema_extra.get("field_kind", None) # must have a field_kind - if field_kind is None or field_kind not in {"input", "output", "internal"}: + if not isinstance(field_kind, FieldKind): raise InvalidFieldError( f'Invalid field definition for "{name}" on "{model_type}" (maybe it\'s not an InputField or OutputField?)' ) - if field_kind == "input" and name in RESERVED_INPUT_FIELD_NAMES: + if field_kind is FieldKind.Input and ( + name in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES or name in RESERVED_INPUT_FIELD_NAMES + ): raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved input field name)') - if field_kind == "output" and name in RESERVED_OUTPUT_FIELD_NAMES: + if field_kind is FieldKind.Output and name in RESERVED_OUTPUT_FIELD_NAMES: raise InvalidFieldError(f'Invalid field name "{name}" on "{model_type}" (reserved output field name)') - # internal fields *must* be in the reserved list - if ( - field_kind == "internal" - and name not in RESERVED_INPUT_FIELD_NAMES - and name not in RESERVED_OUTPUT_FIELD_NAMES - ): + if (field_kind is FieldKind.Internal) and name not in RESERVED_INPUT_FIELD_NAMES: raise InvalidFieldError( f'Invalid field name "{name}" on "{model_type}" (internal field without reserved name)' ) + # node attribute fields *must* be in the reserved list + if ( + field_kind is FieldKind.NodeAttribute + and name not in RESERVED_NODE_ATTRIBUTE_FIELD_NAMES + and name not in RESERVED_OUTPUT_FIELD_NAMES + ): + raise InvalidFieldError( + f'Invalid field name "{name}" on "{model_type}" (node attribute field without reserved name)' + ) + + ui_type = field.json_schema_extra.get("ui_type", None) + if isinstance(ui_type, str) and ui_type.startswith("DEPRECATED_"): + logger.warn(f"\"UIType.{ui_type.split('_')[-1]}\" is deprecated, ignoring") + field.json_schema_extra.pop("ui_type") return None @@ -749,7 +807,7 @@ def invocation( invocation_type_annotation = Literal[invocation_type] # type: ignore invocation_type_field = Field( - title="type", default=invocation_type, json_schema_extra={"_field_kind": "internal"} + title="type", default=invocation_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) docstring = cls.__doc__ @@ -795,7 +853,9 @@ def invocation_output( # Add the output type to the model. output_type_annotation = Literal[output_type] # type: ignore - output_type_field = Field(title="type", default=output_type, json_schema_extra={"_field_kind": "internal"}) + output_type_field = Field( + title="type", default=output_type, json_schema_extra={"field_kind": FieldKind.NodeAttribute} + ) docstring = cls.__doc__ cls = create_model( @@ -827,7 +887,7 @@ WorkflowFieldValidator = TypeAdapter(WorkflowField) class WithWorkflow(BaseModel): workflow: Optional[WorkflowField] = Field( - default=None, description=FieldDescriptions.workflow, json_schema_extra={"_field_kind": "internal"} + default=None, description=FieldDescriptions.workflow, json_schema_extra={"field_kind": FieldKind.NodeAttribute} ) @@ -845,5 +905,11 @@ MetadataFieldValidator = TypeAdapter(MetadataField) class WithMetadata(BaseModel): metadata: Optional[MetadataField] = Field( - default=None, description=FieldDescriptions.metadata, json_schema_extra={"_field_kind": "internal"} + default=None, + description=FieldDescriptions.metadata, + json_schema_extra=InputFieldJSONSchemaExtra( + field_kind=FieldKind.Internal, + input=Input.Connection, + orig_required=False, + ).model_dump(exclude_none=True), ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index f26eebe1ff..4c7b6f94cd 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -5,7 +5,7 @@ import numpy as np from pydantic import ValidationInfo, field_validator from invokeai.app.invocations.primitives import IntegerCollectionOutput -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation @@ -55,7 +55,7 @@ class RangeOfSizeInvocation(BaseInvocation): title="Random Range", tags=["range", "integer", "random", "collection"], category="collections", - version="1.0.0", + version="1.0.1", use_cache=False, ) class RandomRangeInvocation(BaseInvocation): @@ -65,10 +65,10 @@ class RandomRangeInvocation(BaseInvocation): high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") size: int = InputField(default=1, description="The number of values to generate") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed for the RNG (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: diff --git a/invokeai/app/invocations/custom_nodes/init.py b/invokeai/app/invocations/custom_nodes/init.py index c6708e95a7..a379a35fbf 100644 --- a/invokeai/app/invocations/custom_nodes/init.py +++ b/invokeai/app/invocations/custom_nodes/init.py @@ -39,6 +39,8 @@ for d in Path(__file__).parent.iterdir(): logger.warn(f"Could not load {init}") continue + logger.info(f"Loading node pack {spec.name}") + module = module_from_spec(spec) sys.modules[spec.name] = module spec.loader.exec_module(module) @@ -47,5 +49,5 @@ for d in Path(__file__).parent.iterdir(): del init, module_name - -logger.info(f"Loaded {loaded_count} modules from {Path(__file__).parent}") +if loaded_count > 0: + logger.info(f"Loaded {loaded_count} node packs from {Path(__file__).parent}") diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index 9905aa1b5e..0822a4ce2d 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -8,7 +8,7 @@ from PIL import Image, ImageOps from invokeai.app.invocations.primitives import ColorField, ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint from invokeai.backend.image_util.lama import LaMA from invokeai.backend.image_util.patchmatch import PatchMatch @@ -154,17 +154,17 @@ class InfillColorInvocation(BaseInvocation, WithWorkflow, WithMetadata): ) -@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.0") +@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.1.1") class InfillTileInvocation(BaseInvocation, WithWorkflow, WithMetadata): """Infills transparent areas of an image with tiles of the image""" image: ImageField = InputField(description="The image to infill") tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description="The seed to use for tile generation (omit for random)", - default_factory=get_random_seed, ) def invoke(self, context: InvocationContext) -> ImageOutput: diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 485932e18d..e0f582eab8 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import ( InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -67,7 +66,7 @@ class IPAdapterInvocation(BaseInvocation): # weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float) weight: Union[float, List[float]] = InputField( - default=1, ge=-1, description="The weight given to the IP-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=-1, description="The weight given to the IP-Adapter", title="Weight" ) begin_step_percent: float = InputField( diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 9d4afb7020..d438bcae02 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -274,7 +274,10 @@ class DenoiseLatentsInvocation(BaseInvocation): ui_order=7, ) latents: Optional[LatentsField] = InputField( - default=None, description=FieldDescriptions.latents, input=Input.Connection + default=None, + description=FieldDescriptions.latents, + input=Input.Connection, + ui_order=4, ) denoise_mask: Optional[DenoiseMaskField] = InputField( default=None, diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 8cce9bdb88..99dcc72999 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -14,7 +14,6 @@ from .baseinvocation import ( InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -395,7 +394,6 @@ class VaeLoaderInvocation(BaseInvocation): vae_model: VAEModelField = InputField( description=FieldDescriptions.vae_model, input=Input.Direct, - ui_type=UIType.VaeModel, title="VAE", ) diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index e975b7bf22..b1ee91e1cd 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -6,7 +6,7 @@ from pydantic import field_validator from invokeai.app.invocations.latent import LatentsField from invokeai.app.shared.fields import FieldDescriptions -from invokeai.app.util.misc import SEED_MAX, get_random_seed +from invokeai.app.util.misc import SEED_MAX from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( @@ -83,16 +83,16 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): title="Noise", tags=["latents", "noise"], category="latents", - version="1.0.0", + version="1.0.1", ) class NoiseInvocation(BaseInvocation): """Generates latent noise.""" seed: int = InputField( + default=0, ge=0, le=SEED_MAX, description=FieldDescriptions.seed, - default_factory=get_random_seed, ) width: int = InputField( default=512, diff --git a/invokeai/app/invocations/primitives.py b/invokeai/app/invocations/primitives.py index ccfb7dcbb3..afe8ff06d9 100644 --- a/invokeai/app/invocations/primitives.py +++ b/invokeai/app/invocations/primitives.py @@ -62,12 +62,12 @@ class BooleanInvocation(BaseInvocation): title="Boolean Collection Primitive", tags=["primitives", "boolean", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class BooleanCollectionInvocation(BaseInvocation): """A collection of boolean primitive values""" - collection: list[bool] = InputField(default_factory=list, description="The collection of boolean values") + collection: list[bool] = InputField(default=[], description="The collection of boolean values") def invoke(self, context: InvocationContext) -> BooleanCollectionOutput: return BooleanCollectionOutput(collection=self.collection) @@ -111,12 +111,12 @@ class IntegerInvocation(BaseInvocation): title="Integer Collection Primitive", tags=["primitives", "integer", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class IntegerCollectionInvocation(BaseInvocation): """A collection of integer primitive values""" - collection: list[int] = InputField(default_factory=list, description="The collection of integer values") + collection: list[int] = InputField(default=[], description="The collection of integer values") def invoke(self, context: InvocationContext) -> IntegerCollectionOutput: return IntegerCollectionOutput(collection=self.collection) @@ -158,12 +158,12 @@ class FloatInvocation(BaseInvocation): title="Float Collection Primitive", tags=["primitives", "float", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class FloatCollectionInvocation(BaseInvocation): """A collection of float primitive values""" - collection: list[float] = InputField(default_factory=list, description="The collection of float values") + collection: list[float] = InputField(default=[], description="The collection of float values") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: return FloatCollectionOutput(collection=self.collection) @@ -205,12 +205,12 @@ class StringInvocation(BaseInvocation): title="String Collection Primitive", tags=["primitives", "string", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class StringCollectionInvocation(BaseInvocation): """A collection of string primitive values""" - collection: list[str] = InputField(default_factory=list, description="The collection of string values") + collection: list[str] = InputField(default=[], description="The collection of string values") def invoke(self, context: InvocationContext) -> StringCollectionOutput: return StringCollectionOutput(collection=self.collection) @@ -467,13 +467,13 @@ class ConditioningInvocation(BaseInvocation): title="Conditioning Collection Primitive", tags=["primitives", "conditioning", "collection"], category="primitives", - version="1.0.0", + version="1.0.1", ) class ConditioningCollectionInvocation(BaseInvocation): """A collection of conditioning tensor primitive values""" collection: list[ConditioningField] = InputField( - default_factory=list, + default=[], description="The collection of conditioning tensors", ) diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 8ff8ca762c..2412a00079 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -9,7 +9,6 @@ from invokeai.app.invocations.baseinvocation import ( InputField, InvocationContext, OutputField, - UIType, invocation, invocation_output, ) @@ -59,7 +58,7 @@ class T2IAdapterInvocation(BaseInvocation): ui_order=-1, ) weight: Union[float, list[float]] = InputField( - default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight" + default=1, ge=0, description="The weight given to the T2I-Adapter", title="Weight" ) begin_step_percent: float = InputField( default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)" diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 29af1e2333..ee86ef17c6 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -205,7 +205,7 @@ class IterateInvocationOutput(BaseInvocationOutput): """Used to connect iteration outputs. Will be expanded to a specific output.""" item: Any = OutputField( - description="The item being iterated over", title="Collection Item", ui_type=UIType.CollectionItem + description="The item being iterated over", title="Collection Item", ui_type=UIType._CollectionItem ) @@ -215,7 +215,7 @@ class IterateInvocation(BaseInvocation): """Iterates over a list of items""" collection: list[Any] = InputField( - description="The list of items to iterate over", default_factory=list, ui_type=UIType.Collection + description="The list of items to iterate over", default=[], ui_type=UIType._Collection ) index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) @@ -227,7 +227,7 @@ class IterateInvocation(BaseInvocation): @invocation_output("collect_output") class CollectInvocationOutput(BaseInvocationOutput): collection: list[Any] = OutputField( - description="The collection of input items", title="Collection", ui_type=UIType.Collection + description="The collection of input items", title="Collection", ui_type=UIType._Collection ) @@ -238,12 +238,12 @@ class CollectInvocation(BaseInvocation): item: Optional[Any] = InputField( default=None, description="The item to collect (all inputs must be of the same type)", - ui_type=UIType.CollectionItem, + ui_type=UIType._CollectionItem, title="Collection Item", input=Input.Connection, ) collection: list[Any] = InputField( - description="The collection, will be provided on execution", default_factory=list, ui_hidden=True + description="The collection, will be provided on execution", default=[], ui_hidden=True ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 9951b21cd8..faa870bd32 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -805,6 +805,8 @@ "clipField": "Clip", "clipFieldDescription": "Tokenizer and text_encoder submodels.", "collection": "Collection", + "collectionFieldType": "{{name}} Collection", + "polymorphicFieldType": "{{name}} Polymorphic", "collectionDescription": "TODO", "collectionItem": "Collection Item", "collectionItemDescription": "TODO", @@ -891,10 +893,15 @@ "mainModelField": "Model", "mainModelFieldDescription": "TODO", "maybeIncompatible": "May be Incompatible With Installed", - "mismatchedVersion": "Has Mismatched Version", + "mismatchedVersion": "Invalid node: node {{node}} of type {{type}} has mismatched version (try updating?)", "missingCanvaInitImage": "Missing canvas init image", "missingCanvaInitMaskImages": "Missing canvas init and mask images", - "missingTemplate": "Missing Template", + "missingTemplate": "Invalid node: node {{node}} of type {{type}} missing template (not installed?)", + "sourceNodeDoesNotExist": "Invalid edge: source/output node {{node}} does not exist", + "targetNodeDoesNotExist": "Invalid edge: target/input node {{node}} does not exist", + "sourceNodeFieldDoesNotExist": "Invalid edge: source/output field {{node}}.{{field}} does not exist", + "targetNodeFieldDoesNotExist": "Invalid edge: target/input field {{node}}.{{field}} does not exist", + "deletedInvalidEdge": "Deleted invalid edge {{source}} -> {{target}}", "noConnectionData": "No connection data", "noConnectionInProgress": "No connection in progress", "node": "Node", @@ -954,10 +961,17 @@ "stringDescription": "Strings are text.", "stringPolymorphic": "String Polymorphic", "stringPolymorphicDescription": "A collection of strings.", - "unableToLoadWorkflow": "Unable to Validate Workflow", + "unableToLoadWorkflow": "Unable to Load Workflow", "unableToParseEdge": "Unable to parse edge", "unableToParseNode": "Unable to parse node", + "unableToUpdateNode": "Unable to update node", "unableToValidateWorkflow": "Unable to Validate Workflow", + "unknownErrorValidatingWorkflow": "Unknown error validating workflow", + "inputFieldTypeParseError": "Unable to parse type of input field {{node}}.{{field}} ({{message}})", + "outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})", + "unableToExtractSchemaNameFromRef": "unable to extract schema name from ref", + "unsupportedArrayItemType": "unsupported array item type \"{{type}}\"", + "unableToParseFieldType": "unable to parse field type", "uNetField": "UNet", "uNetFieldDescription": "UNet submodel.", "unhandledInputProperty": "Unhandled input property", @@ -971,8 +985,9 @@ "unkownInvocation": "Unknown Invocation type", "unknownOutput": "Unknown output", "updateNode": "Update Node", - "updateAllNodes": "Update All Nodes", "updateApp": "Update App", + "updateAllNodes": "Update All Nodes", + "allNodesUpdated": "All Nodes Updated", "unableToUpdateNodes_one": "Unable to update {{count}} node", "unableToUpdateNodes_other": "Unable to update {{count}} nodes", "vaeField": "Vae", @@ -981,6 +996,8 @@ "vaeModelFieldDescription": "TODO", "validateConnections": "Validate Connections and Graph", "validateConnectionsHelp": "Prevent invalid connections from being made, and invalid graphs from being invoked", + "unableToGetWorkflowVersion": "Unable to get workflow schema version", + "unrecognizedWorkflowVersion": "Unrecognized workflow schema version {{version}}", "version": "Version", "versionUnknown": " Version Unknown", "workflow": "Workflow", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 9c1727fc79..4a41cb3db6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -71,7 +71,7 @@ import { addSocketUnsubscribedEventListener as addSocketUnsubscribedListener } f import { addStagingAreaImageSavedListener } from './listeners/stagingAreaImageSaved'; import { addTabChangedListener } from './listeners/tabChanged'; import { addUpscaleRequestedListener } from './listeners/upscaleRequested'; -import { addWorkflowLoadedListener } from './listeners/workflowLoaded'; +import { addWorkflowLoadRequestedListener } from './listeners/workflowLoadRequested'; import { addUpdateAllNodesRequestedListener } from './listeners/updateAllNodesRequested'; export const listenerMiddleware = createListenerMiddleware(); @@ -178,7 +178,7 @@ addBoardIdSelectedListener(); addReceivedOpenAPISchemaListener(); // Workflows -addWorkflowLoadedListener(); +addWorkflowLoadRequestedListener(); addUpdateAllNodesRequestedListener(); // DND diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts index 1996ec99a5..0966a8c86b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/controlNetImageProcessed.ts @@ -12,10 +12,10 @@ import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; -import { isImageOutput } from 'services/api/guards'; import { BatchConfig, ImageDTO } from 'services/api/types'; import { socketInvocationComplete } from 'services/events/actions'; import { startAppListening } from '..'; +import { isImageOutput } from 'features/nodes/types/common'; export const addControlNetImageProcessedListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts index bd5422841f..f23b7284fe 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDeleted.ts @@ -5,19 +5,20 @@ import { controlAdapterProcessedImageChanged, selectControlAdapterAll, } from 'features/controlAdapters/store/controlAdaptersSlice'; +import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; import { imageDeletionConfirmed } from 'features/deleteImageModal/store/actions'; import { isModalOpenChanged } from 'features/deleteImageModal/store/slice'; import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; import { imageSelected } from 'features/gallery/store/gallerySlice'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isImageFieldInputInstance } from 'features/nodes/types/field'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { clearInitialImage } from 'features/parameters/store/generationSlice'; import { clamp, forEach } from 'lodash-es'; import { api } from 'services/api'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; import { startAppListening } from '..'; -import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; export const addRequestedSingleImageDeletionListener = () => { startAppListening({ @@ -121,7 +122,7 @@ export const addRequestedSingleImageDeletionListener = () => { forEach(node.data.inputs, (input) => { if ( - input.type === 'ImageField' && + isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name ) { dispatch( @@ -241,7 +242,7 @@ export const addRequestedMultipleImageDeletionListener = () => { forEach(node.data.inputs, (input) => { if ( - input.type === 'ImageField' && + isImageFieldInputInstance(input) && input.value?.image_name === imageDTO.image_name ) { dispatch( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts index 6ed0b93e99..e4175affe6 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelSelected.ts @@ -12,12 +12,12 @@ import { setWidth, vaeSelected, } from 'features/parameters/store/generationSlice'; -import { zMainOrOnnxModel } from 'features/parameters/types/parameterSchemas'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; import { forEach } from 'lodash-es'; import { startAppListening } from '..'; +import { zParameterModel } from 'features/parameters/types/parameterSchemas'; export const addModelSelectedListener = () => { startAppListening({ @@ -26,7 +26,7 @@ export const addModelSelectedListener = () => { const log = logger('models'); const state = getState(); - const result = zMainOrOnnxModel.safeParse(action.payload); + const result = zParameterModel.safeParse(action.payload); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 785630495b..afb390470b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -11,9 +11,9 @@ import { vaeSelected, } from 'features/parameters/store/generationSlice'; import { - zMainOrOnnxModel, - zSDXLRefinerModel, - zVaeModel, + zParameterModel, + zParameterSDXLRefinerModel, + zParameterVAEModel, } from 'features/parameters/types/parameterSchemas'; import { refinerModelChanged, @@ -67,7 +67,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zMainOrOnnxModel.safeParse(models[0]); + const result = zParameterModel.safeParse(models[0]); if (!result.success) { log.error( @@ -119,7 +119,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zSDXLRefinerModel.safeParse(models[0]); + const result = zParameterSDXLRefinerModel.safeParse(models[0]); if (!result.success) { log.error( @@ -170,7 +170,7 @@ export const addModelsLoadedListener = () => { return; } - const result = zVaeModel.safeParse(firstModel); + const result = zParameterVAEModel.safeParse(firstModel); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index 5599913a18..f5b630a39d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -15,6 +15,7 @@ export const addReceivedOpenAPISchemaListener = () => { log.debug({ schemaJSON }, 'Received OpenAPI schema'); const { nodesAllowlist, nodesDenylist } = getState().config; + const nodeTemplates = parseSchema( schemaJSON, nodesAllowlist, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index cfd69ce9bc..bc9959b8fc 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -13,13 +13,13 @@ import { } from 'features/nodes/util/graphBuilders/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; -import { isImageOutput } from 'services/api/guards'; import { imagesAdapter } from 'services/api/util'; import { appSocketInvocationComplete, socketInvocationComplete, } from 'services/events/actions'; import { startAppListening } from '../..'; +import { isImageOutput } from 'features/nodes/types/common'; // These nodes output an image, but do not actually *save* an image, so we don't want to handle the gallery logic on them const nodeTypeDenylist = ['load_image', 'image']; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index ece6702ceb..b2383410bd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,14 +1,16 @@ +import { logger } from 'app/logging/logger'; +import { updateAllNodesRequested } from 'features/nodes/store/actions'; +import { nodeReplaced } from 'features/nodes/store/nodesSlice'; import { getNeedsUpdate, updateNode, -} from 'features/nodes/hooks/useNodeVersion'; -import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { nodeReplaced } from 'features/nodes/store/nodesSlice'; -import { startAppListening } from '..'; -import { logger } from 'app/logging/logger'; +} from 'features/nodes/store/util/nodeUpdate'; +import { NodeUpdateError } from 'features/nodes/types/error'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { t } from 'i18next'; +import { startAppListening } from '..'; export const addUpdateAllNodesRequestedListener = () => { startAppListening({ @@ -20,22 +22,31 @@ export const addUpdateAllNodesRequestedListener = () => { let unableToUpdateCount = 0; - nodes.forEach((node) => { + nodes.filter(isInvocationNode).forEach((node) => { const template = templates[node.data.type]; - const needsUpdate = getNeedsUpdate(node, template); - const updatedNode = updateNode(node, template); - if (!updatedNode) { - if (needsUpdate) { - unableToUpdateCount++; - } + if (!template) { + unableToUpdateCount++; return; } - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + if (!getNeedsUpdate(node, template)) { + // No need to increment the count here, since we're not actually updating + return; + } + try { + const updatedNode = updateNode(node, template); + dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); + } catch (e) { + if (e instanceof NodeUpdateError) { + unableToUpdateCount++; + } + } }); if (unableToUpdateCount) { log.warn( - `Unable to update ${unableToUpdateCount} nodes. Please report this issue.` + t('nodes.unableToUpdateNodes', { + count: unableToUpdateCount, + }) ); dispatch( addToast( @@ -46,6 +57,15 @@ export const addUpdateAllNodesRequestedListener = () => { }) ) ); + } else { + dispatch( + addToast( + makeToast({ + title: t('nodes.allNodesUpdated'), + status: 'success', + }) + ) + ); } }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts new file mode 100644 index 0000000000..5336c63942 --- /dev/null +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -0,0 +1,105 @@ +import { logger } from 'app/logging/logger'; +import { parseify } from 'common/util/serialize'; +import { workflowLoadRequested } from 'features/nodes/store/actions'; +import { workflowLoaded } from 'features/nodes/store/nodesSlice'; +import { $flow } from 'features/nodes/store/reactFlowInstance'; +import { WorkflowVersionError } from 'features/nodes/types/error'; +import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; +import { addToast } from 'features/system/store/systemSlice'; +import { makeToast } from 'features/system/util/makeToast'; +import { setActiveTab } from 'features/ui/store/uiSlice'; +import { t } from 'i18next'; +import { z } from 'zod'; +import { fromZodError } from 'zod-validation-error'; +import { startAppListening } from '..'; + +export const addWorkflowLoadRequestedListener = () => { + startAppListening({ + actionCreator: workflowLoadRequested, + effect: (action, { dispatch, getState }) => { + const log = logger('nodes'); + const workflow = action.payload; + const nodeTemplates = getState().nodes.nodeTemplates; + + try { + const { workflow: validatedWorkflow, warnings } = validateWorkflow( + workflow, + nodeTemplates + ); + dispatch(workflowLoaded(validatedWorkflow)); + if (!warnings.length) { + dispatch( + addToast( + makeToast({ + title: t('toast.workflowLoaded'), + status: 'success', + }) + ) + ); + } else { + dispatch( + addToast( + makeToast({ + title: t('toast.loadedWithWarnings'), + status: 'warning', + }) + ) + ); + warnings.forEach(({ message, ...rest }) => { + log.warn(rest, message); + }); + } + + dispatch(setActiveTab('nodes')); + requestAnimationFrame(() => { + $flow.get()?.fitView(); + }); + } catch (e) { + if (e instanceof WorkflowVersionError) { + // The workflow version was not recognized in the valid list of versions + log.error({ error: parseify(e) }, e.message); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: e.message, + }) + ) + ); + } else if (e instanceof z.ZodError) { + // There was a problem validating the workflow itself + const { message } = fromZodError(e, { + prefix: t('nodes.workflowValidation'), + }); + log.error({ error: parseify(e) }, message); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: message, + }) + ) + ); + } else { + // Some other error occurred + console.log(e); + log.error( + { error: parseify(e) }, + t('nodes.unknownErrorValidatingWorkflow') + ); + dispatch( + addToast( + makeToast({ + title: t('nodes.unableToValidateWorkflow'), + status: 'error', + description: t('nodes.unknownErrorValidatingWorkflow'), + }) + ) + ); + } + } + }, + }); +}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts deleted file mode 100644 index de697a70e5..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoaded.ts +++ /dev/null @@ -1,56 +0,0 @@ -import { logger } from 'app/logging/logger'; -import { workflowLoadRequested } from 'features/nodes/store/actions'; -import { workflowLoaded } from 'features/nodes/store/nodesSlice'; -import { $flow } from 'features/nodes/store/reactFlowInstance'; -import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; -import { addToast } from 'features/system/store/systemSlice'; -import { makeToast } from 'features/system/util/makeToast'; -import { setActiveTab } from 'features/ui/store/uiSlice'; -import { startAppListening } from '..'; -import { t } from 'i18next'; - -export const addWorkflowLoadedListener = () => { - startAppListening({ - actionCreator: workflowLoadRequested, - effect: (action, { dispatch, getState }) => { - const log = logger('nodes'); - const workflow = action.payload; - const nodeTemplates = getState().nodes.nodeTemplates; - - const { workflow: validatedWorkflow, errors } = validateWorkflow( - workflow, - nodeTemplates - ); - - dispatch(workflowLoaded(validatedWorkflow)); - - if (!errors.length) { - dispatch( - addToast( - makeToast({ - title: t('toast.workflowLoaded'), - status: 'success', - }) - ) - ); - } else { - dispatch( - addToast( - makeToast({ - title: t('toast.loadedWithWarnings'), - status: 'warning', - }) - ) - ); - errors.forEach(({ message, ...rest }) => { - log.warn(rest, message); - }); - } - - dispatch(setActiveTab('nodes')); - requestAnimationFrame(() => { - $flow.get()?.fitView(); - }); - }, - }); -}; diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 309154db50..b61dfee857 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import i18n from 'i18next'; import { forEach } from 'lodash-es'; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts index 9e293f1104..cdaba3e9a1 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/controlAdaptersSlice.ts @@ -6,9 +6,9 @@ import { isAnyOf, } from '@reduxjs/toolkit'; import { - ControlNetModelParam, - IPAdapterModelParam, - T2IAdapterModelParam, + ParameterControlNetModel, + ParameterIPAdapterModel, + ParameterT2IAdapterModel, } from 'features/parameters/types/parameterSchemas'; import { cloneDeep, merge, uniq } from 'lodash-es'; import { appSocketInvocationError } from 'services/events/actions'; @@ -243,9 +243,9 @@ export const controlAdaptersSlice = createSlice({ action: PayloadAction<{ id: string; model: - | ControlNetModelParam - | T2IAdapterModelParam - | IPAdapterModelParam; + | ParameterControlNetModel + | ParameterT2IAdapterModel + | ParameterIPAdapterModel; }> ) => { const { id, model } = action.payload; diff --git a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts index afc6df45e4..ea63600cdd 100644 --- a/invokeai/frontend/web/src/features/controlAdapters/store/types.ts +++ b/invokeai/frontend/web/src/features/controlAdapters/store/types.ts @@ -1,8 +1,8 @@ import { EntityState } from '@reduxjs/toolkit'; import { - ControlNetModelParam, - IPAdapterModelParam, - T2IAdapterModelParam, + ParameterControlNetModel, + ParameterIPAdapterModel, + ParameterT2IAdapterModel, } from 'features/parameters/types/parameterSchemas'; import { isObject } from 'lodash-es'; import { components } from 'services/api/schema'; @@ -378,7 +378,7 @@ export type ControlNetConfig = { type: 'controlnet'; id: string; isEnabled: boolean; - model: ControlNetModelParam | null; + model: ParameterControlNetModel | null; weight: number; beginStepPct: number; endStepPct: number; @@ -395,7 +395,7 @@ export type T2IAdapterConfig = { type: 't2i_adapter'; id: string; isEnabled: boolean; - model: T2IAdapterModelParam | null; + model: ParameterT2IAdapterModel | null; weight: number; beginStepPct: number; endStepPct: number; @@ -412,7 +412,7 @@ export type IPAdapterConfig = { id: string; isEnabled: boolean; controlImage: string | null; - model: IPAdapterModelParam | null; + model: ParameterIPAdapterModel | null; weight: number; beginStepPct: number; endStepPct: number; diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index d8e68dca21..387d5916fa 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -1,11 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { some } from 'lodash-es'; import { ImageUsage } from './types'; import { selectControlAdapterAll } from 'features/controlAdapters/store/controlAdaptersSlice'; import { isControlNetOrT2IAdapter } from 'features/controlAdapters/store/types'; +import { isImageFieldInputInstance } from 'features/nodes/types/field'; export const getImageUsage = (state: RootState, image_name: string) => { const { generation, canvas, nodes, controlAdapters } = state; @@ -19,7 +20,8 @@ export const getImageUsage = (state: RootState, image_name: string) => { return some( node.data.inputs, (input) => - input.type === 'ImageField' && input.value?.image_name === image_name + isImageFieldInputInstance(input) && + input.value?.image_name === image_name ); }); diff --git a/invokeai/frontend/web/src/features/dnd/types/index.ts b/invokeai/frontend/web/src/features/dnd/types/index.ts index f5254f8a5a..45f325ebd1 100644 --- a/invokeai/frontend/web/src/features/dnd/types/index.ts +++ b/invokeai/frontend/web/src/features/dnd/types/index.ts @@ -11,9 +11,9 @@ import { useDroppable as useOriginalDroppable, } from '@dnd-kit/core'; import { - InputFieldTemplate, - InputFieldValue, -} from 'features/nodes/types/types'; + FieldInputTemplate, + FieldInputInstance, +} from 'features/nodes/types/field'; import { ImageDTO } from 'services/api/types'; type BaseDropData = { @@ -93,8 +93,8 @@ export type NodeFieldDraggableData = BaseDragData & { payloadType: 'NODE_FIELD'; payload: { nodeId: string; - field: InputFieldValue; - fieldTemplate: InputFieldTemplate; + field: FieldInputInstance; + fieldTemplate: FieldInputTemplate; }; }; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx index ce5b178fa2..537df1bd28 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataActions.tsx @@ -4,14 +4,14 @@ import { LoRAMetadataItem, IPAdapterMetadataItem, T2IAdapterMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { useRecallParameters } from 'features/parameters/hooks/useRecallParameters'; import { memo, useMemo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; import { - isValidControlNetModel, - isValidLoRAModel, - isValidT2IAdapterModel, + isParameterControlNetModel, + isParameterLoRAModel, + isParameterT2IAdapterModel, } from '../../../parameters/types/parameterSchemas'; import ImageMetadataItem from './ImageMetadataItem'; @@ -132,7 +132,7 @@ const ImageMetadataActions = (props: Props) => { const validControlNets: ControlNetMetadataItem[] = useMemo(() => { return metadata?.controlnets ? metadata.controlnets.filter((controlnet) => - isValidControlNetModel(controlnet.control_model) + isParameterControlNetModel(controlnet.control_model) ) : []; }, [metadata?.controlnets]); @@ -140,7 +140,7 @@ const ImageMetadataActions = (props: Props) => { const validIPAdapters: IPAdapterMetadataItem[] = useMemo(() => { return metadata?.ipAdapters ? metadata.ipAdapters.filter((ipAdapter) => - isValidControlNetModel(ipAdapter.ip_adapter_model) + isParameterControlNetModel(ipAdapter.ip_adapter_model) ) : []; }, [metadata?.ipAdapters]); @@ -148,7 +148,7 @@ const ImageMetadataActions = (props: Props) => { const validT2IAdapters: T2IAdapterMetadataItem[] = useMemo(() => { return metadata?.t2iAdapters ? metadata.t2iAdapters.filter((t2iAdapter) => - isValidT2IAdapterModel(t2iAdapter.t2i_adapter_model) + isParameterT2IAdapterModel(t2iAdapter.t2i_adapter_model) ) : []; }, [metadata?.t2iAdapters]); @@ -157,8 +157,6 @@ const ImageMetadataActions = (props: Props) => { return null; } - console.log(metadata); - return ( <> {metadata.created_by && ( @@ -275,7 +273,7 @@ const ImageMetadataActions = (props: Props) => { )} {metadata.loras && metadata.loras.map((lora, index) => { - if (isValidLoRAModel(lora.lora)) { + if (isParameterLoRAModel(lora.lora)) { return ( { const { t } = useTranslation(); const fieldFilter = useAppSelector( - (state) => state.nodes.currentConnectionFieldType + (state) => state.nodes.connectionStartFieldType ); const handleFilter = useAppSelector( (state) => state.nodes.connectionStartParams?.handleType @@ -111,7 +110,7 @@ const AddNodePopover = () => { data.sort((a, b) => a.label.localeCompare(b.label)); - return { data, t }; + return { data }; }, defaultSelectorOptions ); @@ -121,7 +120,7 @@ const AddNodePopover = () => { const inputRef = useRef(null); const addNode = useCallback( - (nodeType: AnyInvocationType) => { + (nodeType: string) => { const invocation = buildInvocation(nodeType); if (!invocation) { const errorMessage = t('nodes.unknownNode', { @@ -145,7 +144,7 @@ const AddNodePopover = () => { return; } - addNode(v as AnyInvocationType); + addNode(v); }, [addNode] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx index a379be7ee2..f3d705b347 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/connectionLines/CustomConnectionLine.tsx @@ -2,18 +2,17 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; import { memo } from 'react'; import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; +import { getFieldColor } from '../edges/util/getEdgeColor'; const selector = createSelector(stateSelector, ({ nodes }) => { - const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = + const { shouldAnimateEdges, connectionStartFieldType, shouldColorEdges } = nodes; - const stroke = - currentConnectionFieldType && shouldColorEdges - ? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color) - : colorTokenToCssVar('base.500'); + const stroke = shouldColorEdges + ? getFieldColor(connectionStartFieldType) + : colorTokenToCssVar('base.500'); let className = 'react-flow__custom_connection-path'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts new file mode 100644 index 0000000000..15c63b0bae --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/getEdgeColor.ts @@ -0,0 +1,12 @@ +import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { FIELD_COLORS } from 'features/nodes/types/constants'; +import { FieldType } from 'features/nodes/types/field'; + +export const getFieldColor = (fieldType: FieldType | null): string => { + if (!fieldType) { + return colorTokenToCssVar('base.500'); + } + const color = FIELD_COLORS[fieldType.name]; + + return color ? colorTokenToCssVar(color) : colorTokenToCssVar('base.500'); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts index b5dc484eae..73d3d5dc4d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/makeEdgeSelector.ts @@ -2,8 +2,8 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { FIELDS } from 'features/nodes/types/constants'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; +import { getFieldColor } from './getEdgeColor'; export const makeEdgeSelector = ( source: string, @@ -29,7 +29,7 @@ export const makeEdgeSelector = ( const stroke = sourceType && nodes.shouldColorEdges - ? colorTokenToCssVar(FIELDS[sourceType].color) + ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); return { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx index 30e02bfd84..b1ca6ac22f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeCollapsedHandles.tsx @@ -1,7 +1,7 @@ import { useColorModeValue } from '@chakra-ui/react'; import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; import { useNodeData } from 'features/nodes/hooks/useNodeData'; -import { isInvocationNodeData } from 'features/nodes/types/types'; +import { isInvocationNodeData } from 'features/nodes/types/invocation'; import { map } from 'lodash-es'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, Position } from 'reactflow'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx index 83867a35cb..a439538075 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeInfoIcon.tsx @@ -2,8 +2,8 @@ import { Flex, Icon, Text, Tooltip } from '@chakra-ui/react'; import { compare } from 'compare-versions'; import { useNodeData } from 'features/nodes/hooks/useNodeData'; import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate'; -import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion'; -import { isInvocationNodeData } from 'features/nodes/types/types'; +import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; +import { isInvocationNodeData } from 'features/nodes/types/invocation'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaInfoCircle } from 'react-icons/fa'; @@ -13,7 +13,7 @@ interface Props { } const InvocationNodeInfoIcon = ({ nodeId }: Props) => { - const { needsUpdate } = useNodeVersion(nodeId); + const needsUpdate = useNodeNeedsUpdate(nodeId); return ( { const { status, progress, progressImage } = nodeExecutionState; const { t } = useTranslation(); - if (status === NodeStatus.PENDING) { + if (status === zNodeStatus.enum.PENDING) { return {t('queue.pending')}; } - if (status === NodeStatus.IN_PROGRESS) { + if (status === zNodeStatus.enum.IN_PROGRESS) { if (progressImage) { return ( @@ -108,11 +111,11 @@ const TooltipLabel = memo(({ nodeExecutionState }: TooltipLabelProps) => { return {t('nodes.executionStateInProgress')}; } - if (status === NodeStatus.COMPLETED) { + if (status === zNodeStatus.enum.COMPLETED) { return {t('nodes.executionStateCompleted')}; } - if (status === NodeStatus.FAILED) { + if (status === zNodeStatus.enum.FAILED) { return {t('nodes.executionStateError')}; } @@ -127,7 +130,7 @@ type StatusIconProps = { const StatusIcon = memo((props: StatusIconProps) => { const { progress, status } = props.nodeExecutionState; - if (status === NodeStatus.PENDING) { + if (status === zNodeStatus.enum.PENDING) { return ( { /> ); } - if (status === NodeStatus.IN_PROGRESS) { + if (status === zNodeStatus.enum.IN_PROGRESS) { return progress === null ? ( { /> ); } - if (status === NodeStatus.COMPLETED) { + if (status === zNodeStatus.enum.COMPLETED) { return ( { /> ); } - if (status === NodeStatus.FAILED) { + if (status === zNodeStatus.enum.FAILED) { return ( { ); const mayExpose = useMemo( - () => ['any', 'direct'].includes(input ?? '__UNKNOWN_INPUT__'), + () => input && ['any', 'direct'].includes(input), [input] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx index 3166590254..a622e5018c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/FieldHandle.tsx @@ -1,18 +1,17 @@ import { Tooltip } from '@chakra-ui/react'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; +import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import { - COLLECTION_TYPES, - FIELDS, HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES, - POLYMORPHIC_TYPES, } from 'features/nodes/types/constants'; import { - InputFieldTemplate, - OutputFieldTemplate, -} from 'features/nodes/types/types'; + FieldInputTemplate, + FieldOutputTemplate, +} from 'features/nodes/types/field'; import { CSSProperties, memo, useMemo } from 'react'; import { Handle, HandleType, Position } from 'reactflow'; +import { getFieldColor } from '../../../edges/util/getEdgeColor'; export const handleBaseStyles: CSSProperties = { position: 'absolute', @@ -32,11 +31,11 @@ export const outputHandleStyles: CSSProperties = { }; type FieldHandleProps = { - fieldTemplate: InputFieldTemplate | OutputFieldTemplate; + fieldTemplate: FieldInputTemplate | FieldOutputTemplate; handleType: HandleType; isConnectionInProgress: boolean; isConnectionStartField: boolean; - connectionError: string | null; + connectionError?: string; }; const FieldHandle = (props: FieldHandleProps) => { @@ -47,23 +46,21 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionStartField, connectionError, } = props; - const { name, type } = fieldTemplate; - const { color: typeColor, title } = FIELDS[type]; - + const { name } = fieldTemplate; + const type = fieldTemplate.type; + const fieldTypeName = useFieldTypeName(type); const styles: CSSProperties = useMemo(() => { - const isCollectionType = COLLECTION_TYPES.includes(type); - const isPolymorphicType = POLYMORPHIC_TYPES.includes(type); - const isModelType = MODEL_TYPES.includes(type); - const color = colorTokenToCssVar(typeColor); + const isModelType = MODEL_TYPES.some((t) => t === type.name); + const color = getFieldColor(type); const s: CSSProperties = { backgroundColor: - isCollectionType || isPolymorphicType - ? 'var(--invokeai-colors-base-900)' + type.isCollection || type.isPolymorphic + ? colorTokenToCssVar('base.900') : color, position: 'absolute', width: '1rem', height: '1rem', - borderWidth: isCollectionType || isPolymorphicType ? 4 : 0, + borderWidth: type.isCollection || type.isPolymorphic ? 4 : 0, borderStyle: 'solid', borderColor: color, borderRadius: isModelType ? 4 : '100%', @@ -97,18 +94,14 @@ const FieldHandle = (props: FieldHandleProps) => { isConnectionInProgress, isConnectionStartField, type, - typeColor, ]); const tooltip = useMemo(() => { - if (isConnectionInProgress && isConnectionStartField) { - return title; - } if (isConnectionInProgress && connectionError) { - return connectionError ?? title; + return connectionError; } - return title; - }, [connectionError, isConnectionInProgress, isConnectionStartField, title]); + return fieldTypeName; + }, [connectionError, fieldTypeName, isConnectionInProgress]); return ( { - const field = useFieldData(nodeId, fieldName); + const field = useFieldInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, kind); - const isInputTemplate = isInputFieldTemplate(fieldTemplate); + const isInputTemplate = isFieldInputTemplate(fieldTemplate); + const fieldTypeName = useFieldTypeName(fieldTemplate?.type); const { t } = useTranslation(); const fieldTitle = useMemo(() => { - if (isInputFieldValue(field)) { + if (isFieldInputInstance(field)) { if (field.label && fieldTemplate?.title) { return `${field.label} (${fieldTemplate.title})`; } @@ -49,9 +49,9 @@ const FieldTooltipContent = ({ nodeId, fieldName, kind }: Props) => { {fieldTemplate.description} )} - {fieldTemplate && ( + {fieldTypeName && ( - {t('parameters.type')}: {FIELDS[fieldTemplate.type].title} + {t('parameters.type')}: {fieldTypeName} )} {isInputTemplate && ( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index 4a48971602..dac9404c26 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -77,10 +77,10 @@ const InputField = ({ nodeId, fieldName }: Props) => { sx={{ display: 'flex', alignItems: 'center', - h: 'full', mb: 0, px: 1, gap: 2, + h: 'full', }} > { const { t } = useTranslation(); - const field = useFieldData(nodeId, fieldName); + const fieldInstance = useFieldInstance(nodeId, fieldName); const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); if (fieldTemplate?.fieldKind === 'output') { return ( - {t('nodes.outputFieldInInput')}: {field?.type} + {t('nodes.outputFieldInInput')}: {fieldInstance?.type.name} ); } if ( - (field?.type === 'string' && fieldTemplate?.type === 'string') || - (field?.type === 'StringPolymorphic' && - fieldTemplate?.type === 'StringPolymorphic') + isStringFieldInputInstance(fieldInstance) && + isStringFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - (field?.type === 'boolean' && fieldTemplate?.type === 'boolean') || - (field?.type === 'BooleanPolymorphic' && - fieldTemplate?.type === 'BooleanPolymorphic') + isBooleanFieldInputInstance(fieldInstance) && + isBooleanFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - (field?.type === 'integer' && fieldTemplate?.type === 'integer') || - (field?.type === 'float' && fieldTemplate?.type === 'float') || - (field?.type === 'FloatPolymorphic' && - fieldTemplate?.type === 'FloatPolymorphic') || - (field?.type === 'IntegerPolymorphic' && - fieldTemplate?.type === 'IntegerPolymorphic') + (isIntegerFieldInputInstance(fieldInstance) && + isIntegerFieldInputTemplate(fieldTemplate)) || + (isFloatFieldInputInstance(fieldInstance) && + isFloatFieldInputTemplate(fieldTemplate)) ) { return ( - - ); - } - - if (field?.type === 'enum' && fieldTemplate?.type === 'enum') { - return ( - ); } if ( - (field?.type === 'ImageField' && fieldTemplate?.type === 'ImageField') || - (field?.type === 'ImagePolymorphic' && - fieldTemplate?.type === 'ImagePolymorphic') + isEnumFieldInputInstance(fieldInstance) && + isEnumFieldInputTemplate(fieldTemplate) ) { return ( - - ); - } - - if (field?.type === 'BoardField' && fieldTemplate?.type === 'BoardField') { - return ( - ); } if ( - field?.type === 'MainModelField' && - fieldTemplate?.type === 'MainModelField' + isImageFieldInputInstance(fieldInstance) && + isImageFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'SDXLRefinerModelField' && - fieldTemplate?.type === 'SDXLRefinerModelField' + isBoardFieldInputInstance(fieldInstance) && + isBoardFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'VaeModelField' && - fieldTemplate?.type === 'VaeModelField' + isMainModelFieldInputInstance(fieldInstance) && + isMainModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'LoRAModelField' && - fieldTemplate?.type === 'LoRAModelField' + isSDXLRefinerModelFieldInputInstance(fieldInstance) && + isSDXLRefinerModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'ControlNetModelField' && - fieldTemplate?.type === 'ControlNetModelField' + isVAEModelFieldInputInstance(fieldInstance) && + isVAEModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'IPAdapterModelField' && - fieldTemplate?.type === 'IPAdapterModelField' + isLoRAModelFieldInputInstance(fieldInstance) && + isLoRAModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } if ( - field?.type === 'T2IAdapterModelField' && - fieldTemplate?.type === 'T2IAdapterModelField' + isControlNetModelFieldInputInstance(fieldInstance) && + isControlNetModelFieldInputTemplate(fieldTemplate) ) { return ( - - ); - } - if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { - return ( - ); } if ( - field?.type === 'SDXLMainModelField' && - fieldTemplate?.type === 'SDXLMainModelField' + isIPAdapterModelFieldInputInstance(fieldInstance) && + isIPAdapterModelFieldInputTemplate(fieldTemplate) ) { return ( - ); } - if (field?.type === 'Scheduler' && fieldTemplate?.type === 'Scheduler') { + if ( + isT2IAdapterModelFieldInputInstance(fieldInstance) && + isT2IAdapterModelFieldInputTemplate(fieldTemplate) + ) { return ( - + ); + } + if ( + isColorFieldInputInstance(fieldInstance) && + isColorFieldInputTemplate(fieldTemplate) + ) { + return ( + ); } - if (field && fieldTemplate) { + if ( + isSDXLMainModelFieldInputInstance(fieldInstance) && + isSDXLMainModelFieldInputTemplate(fieldTemplate) + ) { + return ( + + ); + } + + if ( + isSchedulerFieldInputInstance(fieldInstance) && + isSchedulerFieldInputTemplate(fieldTemplate) + ) { + return ( + + ); + } + + if (fieldInstance && fieldTemplate) { // Fallback for when there is no component for the type return null; } @@ -255,7 +298,7 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { _dark: { color: 'error.300' }, }} > - {t('nodes.unknownFieldType')}: {field?.type} + {t('nodes.unknownFieldType')}: {fieldInstance?.type.name} ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx similarity index 82% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx index a6e8cbb0c1..8f0f3260a6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BoardFieldInputComponent.tsx @@ -3,15 +3,15 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldBoardValueChanged } from 'features/nodes/store/nodesSlice'; import { - BoardInputFieldTemplate, - BoardInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + BoardFieldInputTemplate, + BoardFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback } from 'react'; import { useListAllBoardsQuery } from 'services/api/endpoints/boards'; -const BoardInputFieldComponent = ( - props: FieldComponentProps +const BoardFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); @@ -61,4 +61,4 @@ const BoardInputFieldComponent = ( ); }; -export default memo(BoardInputFieldComponent); +export default memo(BoardFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx similarity index 65% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx index d14756dbdb..3bac81b0f7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/BooleanFieldInputComponent.tsx @@ -2,18 +2,16 @@ import { Switch } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; import { - BooleanInputFieldTemplate, - BooleanInputFieldValue, - BooleanPolymorphicInputFieldTemplate, - BooleanPolymorphicInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + BooleanFieldInputInstance, + BooleanFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const BooleanInputFieldComponent = ( +const BooleanFieldInputComponent = ( props: FieldComponentProps< - BooleanInputFieldValue | BooleanPolymorphicInputFieldValue, - BooleanInputFieldTemplate | BooleanPolymorphicInputFieldTemplate + BooleanFieldInputInstance, + BooleanFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -42,4 +40,4 @@ const BooleanInputFieldComponent = ( ); }; -export default memo(BooleanInputFieldComponent); +export default memo(BooleanFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx similarity index 70% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx index c2af279cb5..875bb06270 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ColorFieldInputComponent.tsx @@ -1,15 +1,15 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { fieldColorValueChanged } from 'features/nodes/store/nodesSlice'; import { - ColorInputFieldTemplate, - ColorInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + ColorFieldInputTemplate, + ColorFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback } from 'react'; import { RgbaColor, RgbaColorPicker } from 'react-colorful'; -const ColorInputFieldComponent = ( - props: FieldComponentProps +const ColorFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; @@ -37,4 +37,4 @@ const ColorInputFieldComponent = ( ); }; -export default memo(ColorInputFieldComponent); +export default memo(ColorFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx index 804671204d..8604e6319e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ControlNetModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldControlNetModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - ControlNetModelInputFieldTemplate, - ControlNetModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + ControlNetModelFieldInputTemplate, + ControlNetModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToControlNetModelParam } from 'features/parameters/util/modelIdToControlNetModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetControlNetModelsQuery } from 'services/api/endpoints/models'; -const ControlNetModelInputFieldComponent = ( +const ControlNetModelFieldInputComponent = ( props: FieldComponentProps< - ControlNetModelInputFieldValue, - ControlNetModelInputFieldTemplate + ControlNetModelFieldInputInstance, + ControlNetModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const ControlNetModelInputFieldComponent = ( ); }; -export default memo(ControlNetModelInputFieldComponent); +export default memo(ControlNetModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx similarity index 77% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx index 277020d847..e741afe964 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/EnumFieldInputComponent.tsx @@ -2,14 +2,14 @@ import { Select } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import { fieldEnumModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - EnumInputFieldTemplate, - EnumInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + EnumFieldInputInstance, + EnumFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const EnumInputFieldComponent = ( - props: FieldComponentProps +const EnumFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field, fieldTemplate } = props; @@ -45,4 +45,4 @@ const EnumInputFieldComponent = ( ); }; -export default memo(EnumInputFieldComponent); +export default memo(EnumFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx index 637fa79f60..0ec332cd5a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/IPAdapterModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - IPAdapterModelInputFieldTemplate, - IPAdapterModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + IPAdapterModelFieldInputTemplate, + IPAdapterModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models'; -const IPAdapterModelInputFieldComponent = ( +const IPAdapterModelFieldInputComponent = ( props: FieldComponentProps< - IPAdapterModelInputFieldValue, - IPAdapterModelInputFieldTemplate + IPAdapterModelFieldInputInstance, + IPAdapterModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const IPAdapterModelInputFieldComponent = ( ); }; -export default memo(IPAdapterModelInputFieldComponent); +export default memo(IPAdapterModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx index 94095f2612..5feb870adc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldInputComponent.tsx @@ -9,23 +9,18 @@ import { } from 'features/dnd/types'; import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - ImageInputFieldTemplate, - ImageInputFieldValue, - ImagePolymorphicInputFieldTemplate, - ImagePolymorphicInputFieldValue, -} from 'features/nodes/types/types'; + ImageFieldInputInstance, + ImageFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaUndo } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { PostUploadAction } from 'services/api/types'; -const ImageInputFieldComponent = ( - props: FieldComponentProps< - ImageInputFieldValue | ImagePolymorphicInputFieldValue, - ImageInputFieldTemplate | ImagePolymorphicInputFieldTemplate - > +const ImageFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field } = props; const dispatch = useAppDispatch(); @@ -102,7 +97,7 @@ const ImageInputFieldComponent = ( ); }; -export default memo(ImageInputFieldComponent); +export default memo(ImageFieldInputComponent); const UploadElement = memo(() => { const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx similarity index 91% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx index dc79436ec6..fa2bada631 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/LoRAModelFieldInputComponent.tsx @@ -5,10 +5,10 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldLoRAModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - LoRAModelInputFieldTemplate, - LoRAModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + LoRAModelFieldInputTemplate, + LoRAModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToLoRAModelParam } from 'features/parameters/util/modelIdToLoRAModelParam'; import { forEach } from 'lodash-es'; @@ -16,10 +16,10 @@ import { memo, useCallback, useMemo } from 'react'; import { useGetLoRAModelsQuery } from 'services/api/endpoints/models'; import { useTranslation } from 'react-i18next'; -const LoRAModelInputFieldComponent = ( +const LoRAModelFieldInputComponent = ( props: FieldComponentProps< - LoRAModelInputFieldValue, - LoRAModelInputFieldTemplate + LoRAModelFieldInputInstance, + LoRAModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -121,4 +121,4 @@ const LoRAModelInputFieldComponent = ( ); }; -export default memo(LoRAModelInputFieldComponent); +export default memo(LoRAModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx similarity index 93% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx index af68b4291c..8c62548924 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/MainModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - MainModelInputFieldTemplate, - MainModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + MainModelFieldInputTemplate, + MainModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -21,10 +21,10 @@ import { } from 'services/api/endpoints/models'; import { useTranslation } from 'react-i18next'; -const MainModelInputFieldComponent = ( +const MainModelFieldInputComponent = ( props: FieldComponentProps< - MainModelInputFieldValue, - MainModelInputFieldTemplate + MainModelFieldInputInstance, + MainModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -149,4 +149,4 @@ const MainModelInputFieldComponent = ( ); }; -export default memo(MainModelInputFieldComponent); +export default memo(MainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx similarity index 71% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx index 2b2763ca3e..9daff53448 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldInputComponent.tsx @@ -9,28 +9,18 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { numberStringRegex } from 'common/components/IAINumberInput'; import { fieldNumberValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - FloatInputFieldTemplate, - FloatInputFieldValue, - FloatPolymorphicInputFieldTemplate, - FloatPolymorphicInputFieldValue, - IntegerInputFieldTemplate, - IntegerInputFieldValue, - IntegerPolymorphicInputFieldTemplate, - IntegerPolymorphicInputFieldValue, -} from 'features/nodes/types/types'; + FloatFieldInputInstance, + FloatFieldInputTemplate, + IntegerFieldInputInstance, + IntegerFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { memo, useCallback, useEffect, useMemo, useState } from 'react'; -const NumberInputFieldComponent = ( +const NumberFieldInputComponent = ( props: FieldComponentProps< - | IntegerInputFieldValue - | IntegerPolymorphicInputFieldValue - | FloatInputFieldValue - | FloatPolymorphicInputFieldValue, - | IntegerInputFieldTemplate - | IntegerPolymorphicInputFieldTemplate - | FloatInputFieldTemplate - | FloatPolymorphicInputFieldTemplate + IntegerFieldInputInstance | FloatFieldInputInstance, + IntegerFieldInputTemplate | FloatFieldInputTemplate > ) => { const { nodeId, field, fieldTemplate } = props; @@ -39,7 +29,7 @@ const NumberInputFieldComponent = ( String(field.value) ); const isIntegerField = useMemo( - () => fieldTemplate.type === 'integer', + () => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type] ); @@ -86,4 +76,4 @@ const NumberInputFieldComponent = ( ); }; -export default memo(NumberInputFieldComponent); +export default memo(NumberFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx similarity index 90% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx index e6db6031b8..42e63a8cb6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/RefinerModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldRefinerModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - SDXLRefinerModelInputFieldTemplate, - SDXLRefinerModelInputFieldValue, -} from 'features/nodes/types/types'; + SDXLRefinerModelFieldInputTemplate, + SDXLRefinerModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -18,10 +18,10 @@ import { useTranslation } from 'react-i18next'; import { REFINER_BASE_MODELS } from 'services/api/constants'; import { useGetMainModelsQuery } from 'services/api/endpoints/models'; -const RefinerModelInputFieldComponent = ( +const RefinerModelFieldInputComponent = ( props: FieldComponentProps< - SDXLRefinerModelInputFieldValue, - SDXLRefinerModelInputFieldTemplate + SDXLRefinerModelFieldInputInstance, + SDXLRefinerModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -120,4 +120,4 @@ const RefinerModelInputFieldComponent = ( ); }; -export default memo(RefinerModelInputFieldComponent); +export default memo(RefinerModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx similarity index 92% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx index c6ef5c6bb4..260f51ee8b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SDXLMainModelFieldInputComponent.tsx @@ -4,10 +4,10 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldMainModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - SDXLMainModelInputFieldTemplate, - SDXLMainModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + SDXLMainModelFieldInputTemplate, + SDXLMainModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToMainModelParam } from 'features/parameters/util/modelIdToMainModelParam'; import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; @@ -21,10 +21,10 @@ import { useGetOnnxModelsQuery, } from 'services/api/endpoints/models'; -const ModelInputFieldComponent = ( +const SDXLMainModelFieldInputComponent = ( props: FieldComponentProps< - SDXLMainModelInputFieldValue, - SDXLMainModelInputFieldTemplate + SDXLMainModelFieldInputInstance, + SDXLMainModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -147,4 +147,4 @@ const ModelInputFieldComponent = ( ); }; -export default memo(ModelInputFieldComponent); +export default memo(SDXLMainModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx similarity index 72% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx index e4a3fb2a3d..a3b30f5057 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/SchedulerFieldInputComponent.tsx @@ -5,14 +5,12 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { fieldSchedulerValueChanged } from 'features/nodes/store/nodesSlice'; import { - SchedulerInputFieldTemplate, - SchedulerInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; + SchedulerFieldInputTemplate, + SchedulerFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -24,7 +22,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -36,10 +34,10 @@ const selector = createSelector( defaultSelectorOptions ); -const SchedulerInputField = ( +const SchedulerFieldInputComponent = ( props: FieldComponentProps< - SchedulerInputFieldValue, - SchedulerInputFieldTemplate + SchedulerFieldInputInstance, + SchedulerFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -55,7 +53,7 @@ const SchedulerInputField = ( fieldSchedulerValueChanged({ nodeId, fieldName: field.name, - value: value as SchedulerParam, + value: value as ParameterScheduler, }) ); }, @@ -72,4 +70,4 @@ const SchedulerInputField = ( ); }; -export default memo(SchedulerInputField); +export default memo(SchedulerFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx similarity index 70% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx index 720722030b..50c8c487da 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldInputComponent.tsx @@ -3,19 +3,14 @@ import IAIInput from 'common/components/IAIInput'; import IAITextarea from 'common/components/IAITextarea'; import { fieldStringValueChanged } from 'features/nodes/store/nodesSlice'; import { - StringInputFieldTemplate, - StringInputFieldValue, - FieldComponentProps, - StringPolymorphicInputFieldValue, - StringPolymorphicInputFieldTemplate, -} from 'features/nodes/types/types'; + StringFieldInputInstance, + StringFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { ChangeEvent, memo, useCallback } from 'react'; -const StringInputFieldComponent = ( - props: FieldComponentProps< - StringInputFieldValue | StringPolymorphicInputFieldValue, - StringInputFieldTemplate | StringPolymorphicInputFieldTemplate - > +const StringFieldInputComponent = ( + props: FieldComponentProps ) => { const { nodeId, field, fieldTemplate } = props; const dispatch = useAppDispatch(); @@ -48,4 +43,4 @@ const StringInputFieldComponent = ( return ; }; -export default memo(StringInputFieldComponent); +export default memo(StringFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx similarity index 87% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx index f5ae6b747a..03b3cba4f0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/T2IAdapterModelFieldInputComponent.tsx @@ -3,20 +3,20 @@ import { useAppDispatch } from 'app/store/storeHooks'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - T2IAdapterModelInputFieldTemplate, - T2IAdapterModelInputFieldValue, - FieldComponentProps, -} from 'features/nodes/types/types'; + T2IAdapterModelFieldInputInstance, + T2IAdapterModelFieldInputTemplate, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models'; -const T2IAdapterModelInputFieldComponent = ( +const T2IAdapterModelFieldInputComponent = ( props: FieldComponentProps< - T2IAdapterModelInputFieldValue, - T2IAdapterModelInputFieldTemplate + T2IAdapterModelFieldInputInstance, + T2IAdapterModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -97,4 +97,4 @@ const T2IAdapterModelInputFieldComponent = ( ); }; -export default memo(T2IAdapterModelInputFieldComponent); +export default memo(T2IAdapterModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx similarity index 89% rename from invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx rename to invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx index 79ada94c3e..93d397b202 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VaeModelInputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/VAEModelFieldInputComponent.tsx @@ -4,20 +4,20 @@ import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSe import IAIMantineSelectItemWithTooltip from 'common/components/IAIMantineSelectItemWithTooltip'; import { fieldVaeModelValueChanged } from 'features/nodes/store/nodesSlice'; import { - FieldComponentProps, - VaeModelInputFieldTemplate, - VaeModelInputFieldValue, -} from 'features/nodes/types/types'; + VAEModelFieldInputTemplate, + VAEModelFieldInputInstance, +} from 'features/nodes/types/field'; +import { FieldComponentProps } from './types'; import { MODEL_TYPE_MAP } from 'features/parameters/types/constants'; import { modelIdToVAEModelParam } from 'features/parameters/util/modelIdToVAEModelParam'; import { forEach } from 'lodash-es'; import { memo, useCallback, useMemo } from 'react'; import { useGetVaeModelsQuery } from 'services/api/endpoints/models'; -const VaeModelInputFieldComponent = ( +const VAEModelFieldInputComponent = ( props: FieldComponentProps< - VaeModelInputFieldValue, - VaeModelInputFieldTemplate + VAEModelFieldInputInstance, + VAEModelFieldInputTemplate > ) => { const { nodeId, field } = props; @@ -105,4 +105,4 @@ const VaeModelInputFieldComponent = ( ); }; -export default memo(VaeModelInputFieldComponent); +export default memo(VAEModelFieldInputComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts new file mode 100644 index 0000000000..22c488c16f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/types.ts @@ -0,0 +1,13 @@ +import { + FieldInputInstance, + FieldInputTemplate, +} from 'features/nodes/types/field'; + +export type FieldComponentProps< + V extends FieldInputInstance, + T extends FieldInputTemplate, +> = { + nodeId: string; + field: V; + fieldTemplate: T; +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx index ec869f3dad..bbbb7b7372 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Notes/NotesNode.tsx @@ -2,7 +2,7 @@ import { Box, Flex } from '@chakra-ui/react'; import { useAppDispatch } from 'app/store/storeHooks'; import IAITextarea from 'common/components/IAITextarea'; import { notesNodeValueChanged } from 'features/nodes/store/nodesSlice'; -import { NotesNodeData } from 'features/nodes/types/types'; +import { NotesNodeData } from 'features/nodes/types/invocation'; import { ChangeEvent, memo, useCallback } from 'react'; import { NodeProps } from 'reactflow'; import NodeWrapper from '../common/NodeWrapper'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx index 79de65760f..b6ccd4ae9f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeWrapper.tsx @@ -14,7 +14,7 @@ import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH, } from 'features/nodes/types/constants'; -import { NodeStatus } from 'features/nodes/types/types'; +import { zNodeStatus } from 'features/nodes/types/invocation'; import { contextMenusClosed } from 'features/ui/store/uiSlice'; import { MouseEvent, @@ -40,7 +40,8 @@ const NodeWrapper = (props: NodeWrapperProps) => { createSelector( stateSelector, ({ nodes }) => - nodes.nodeExecutionStates[nodeId]?.status === NodeStatus.IN_PROGRESS + nodes.nodeExecutionStates[nodeId]?.status === + zNodeStatus.enum.IN_PROGRESS ), [nodeId] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx index 8454f5539f..eb593ee1db 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopCenterPanel/LoadWorkflowButton.tsx @@ -8,7 +8,7 @@ import { FaUpload } from 'react-icons/fa'; const LoadWorkflowButton = () => { const { t } = useTranslation(); const resetRef = useRef<() => void>(null); - const loadWorkflowFromFile = useLoadWorkflowFromFile(); + const loadWorkflowFromFile = useLoadWorkflowFromFile(resetRef); return ( { - return ( - - {map(FIELDS, ({ title, description, color }, key) => ( - - - {title} - - - ))} - - ); -}; - -export default memo(FieldTypeLegend); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx index db8f544c2e..c289ea02dd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/TopRightPanel.tsx @@ -1,18 +1,11 @@ import { Flex } from '@chakra-ui/layout'; -import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; -import FieldTypeLegend from './FieldTypeLegend'; import WorkflowEditorSettings from './WorkflowEditorSettings'; const TopRightPanel = () => { - const shouldShowFieldTypeLegend = useAppSelector( - (state) => state.nodes.shouldShowFieldTypeLegend - ); - return ( - {shouldShowFieldTypeLegend && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index d906557dd3..ecbe538fcc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -10,17 +10,15 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import IAIIconButton from 'common/components/IAIIconButton'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { useNodeVersion } from 'features/nodes/hooks/useNodeVersion'; +import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate'; import { InvocationNodeData, InvocationTemplate, isInvocationNode, -} from 'features/nodes/types/types'; -import { memo } from 'react'; +} from 'features/nodes/types/invocation'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -import { FaSync } from 'react-icons/fa'; import { Node } from 'reactflow'; import NotesTextarea from '../../flow/nodes/Invocation/NotesTextarea'; import ScrollableContent from '../ScrollableContent'; @@ -63,12 +61,17 @@ const InspectorDetailsTab = () => { export default memo(InspectorDetailsTab); -const Content = (props: { +type ContentProps = { node: Node; template: InvocationTemplate; -}) => { +}; + +const Content = memo(({ node, template }: ContentProps) => { const { t } = useTranslation(); - const { needsUpdate, updateNode } = useNodeVersion(props.node.id); + const needsUpdate = useMemo( + () => getNeedsUpdate(node, template), + [node, template] + ); return ( - + {t('nodes.nodeType')} - {props.template.title} + {template.title} {t('nodes.nodeVersion')} - {props.node.data.version} + {node.data.version} - {needsUpdate && ( - } - onClick={updateNode} - /> - )} - + ); -}; +}); + +Content.displayName = 'Content'; diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index f4abc621b4..265d369f5a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -5,7 +5,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; -import { isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { memo } from 'react'; import { ImageOutput } from 'services/api/types'; import { AnyResult } from 'services/events/types'; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index dda2efc156..ccfa0f57fd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; +import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -28,8 +25,8 @@ export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const fields = map(nodeTemplate.inputs).filter( (field) => (['any', 'direct'].includes(field.input) || - POLYMORPHIC_TYPES.includes(field.type)) && - TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + field.type.isPolymorphic) && + keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); return getSortedFilteredFieldNames(fields); }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts index 036ce8d44e..694261d943 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts @@ -3,10 +3,13 @@ import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { Node, useReactFlow } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; -import { buildNodeData } from '../store/util/buildNodeData'; +import { + buildCurrentImageNode, + buildInvocationNode, + buildNotesNode, +} from '../store/util/buildNodeData'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants'; - +import { AnyNodeData, InvocationTemplate } from '../types/invocation'; const templatesSelector = createSelector( [(state: RootState) => state.nodes], (nodes) => nodes.nodeTemplates @@ -22,7 +25,8 @@ export const useBuildNodeData = () => { const flow = useReactFlow(); return useCallback( - (type: AnyInvocationType | 'current_image' | 'notes') => { + // string here is "any invocation type" + (type: string | 'current_image' | 'notes'): Node => { let _x = window.innerWidth / 2; let _y = window.innerHeight / 2; @@ -41,9 +45,19 @@ export const useBuildNodeData = () => { y: _y, }); - const template = nodeTemplates[type]; + if (type === 'current_image') { + return buildCurrentImageNode(position); + } - return buildNodeData(type, position, template); + if (type === 'notes') { + return buildNotesNode(position); + } + + // TODO: Keep track of invocation types so we do not need to cast this + // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. + const template = nodeTemplates[type] as InvocationTemplate; + + return buildInvocationNode(position, template); }, [nodeTemplates, flow] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 9fb31df801..2951167944 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -2,14 +2,11 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { map } from 'lodash-es'; +import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; -import { - POLYMORPHIC_TYPES, - TYPES_WITH_INPUT_COMPONENTS, -} from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; export const useConnectionInputFieldNames = (nodeId: string) => { const selector = useMemo( @@ -29,9 +26,8 @@ export const useConnectionInputFieldNames = (nodeId: string) => { // get the visible fields const fields = map(nodeTemplate.inputs).filter( (field) => - (field.input === 'connection' && - !POLYMORPHIC_TYPES.includes(field.type)) || - !TYPES_WITH_INPUT_COMPONENTS.includes(field.type) + (field.input === 'connection' && !field.type.isPolymorphic) || + !keys(TEMPLATE_BUILDER_MAP).includes(field.type.name) ); return getSortedFilteredFieldNames(fields); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts index 96b2d652e9..cc3b2ce7ac 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts @@ -8,7 +8,7 @@ import { useFieldType } from './useFieldType.ts'; const selectIsConnectionInProgress = createSelector( stateSelector, ({ nodes }) => - nodes.currentConnectionFieldType !== null && + nodes.connectionStartFieldType !== null && nodes.connectionStartParams !== null ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index 926c56ac1e..82db025a55 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { compareVersions } from 'compare-versions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useDoNodeVersionsMatch = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts index 83bf6b8af0..e5264c07c4 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoesInputHaveValue.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useDoesInputHaveValue = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts b/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts index 866d8ab970..863bc64718 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useEmbedWorkflow.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useEmbedWorkflow = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts index ba2c4e2d5c..7cdd44e4fd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldData.ts @@ -3,9 +3,9 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; -export const useFieldData = (nodeId: string, fieldName: string) => { +export const useFieldInstance = (nodeId: string, fieldName: string) => { const selector = useMemo( () => createSelector( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index 159815a6a6..82f90531dd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts index fcf33c3427..cabef729ae 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldLabel.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldLabel = (nodeId: string, fieldName: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index 93d545aaea..a18a027c6b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; import { KIND_MAP } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldTemplate = ( nodeId: string, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index 923c25cc18..faec9c1ff3 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { KIND_MAP } from '../types/constants'; export const useFieldTemplateTitle = ( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts index f4d78f8954..0775c32cb2 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldType.ts.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; import { KIND_MAP } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useFieldType = ( nodeId: string, @@ -20,7 +20,8 @@ export const useFieldType = ( if (!isInvocationNode(node)) { return; } - return node?.data[KIND_MAP[kind]][fieldName]?.type; + const field = node.data[KIND_MAP[kind]][fieldName]; + return field?.type; }, defaultSelectorOptions ), diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index a413de38ae..c22c0d9505 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -2,7 +2,8 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { getNeedsUpdate } from './useNodeVersion'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { isInvocationNode } from '../types/invocation'; const selector = createSelector( stateSelector, @@ -10,8 +11,11 @@ const selector = createSelector( const nodes = state.nodes.nodes; const templates = state.nodes.nodeTemplates; - const needsUpdate = nodes.some((node) => { + const needsUpdate = nodes.filter(isInvocationNode).some((node) => { const template = templates[node.data.type]; + if (!template) { + return false; + } return getNeedsUpdate(node, template); }); return needsUpdate; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts index 111e48a45f..6b99d75ef0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useHasImageOutput.ts @@ -4,8 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { some } from 'lodash-es'; import { useMemo } from 'react'; -import { IMAGE_FIELDS } from '../types/constants'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useHasImageOutput = (nodeId: string) => { const selector = useMemo( @@ -20,8 +19,8 @@ export const useHasImageOutput = (nodeId: string) => { return some( node.data.outputs, (output) => - IMAGE_FIELDS.includes(output.type) && - // the image primitive node does not actually save the image, do not show the image-saving checkboxes + output.type.name === 'ImageField' && + // the image primitive node (node type "image") does not actually save the image, do not show the image-saving checkboxes node.data.type !== 'image' ); }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts index 86b9371b03..167610c14f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsIntermediate.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useIsIntermediate = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index c88d4758af..028b238c7b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -4,7 +4,7 @@ import { useCallback } from 'react'; import { Connection, Node, useReactFlow } from 'reactflow'; import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes'; import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic'; -import { InvocationNodeData } from '../types/types'; +import { InvocationNodeData } from '../types/invocation'; /** * NOTE: The logic here must be duplicated in `invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts` @@ -34,10 +34,10 @@ export const useIsValidConnection = () => { return false; } - const sourceType = sourceNode.data.outputs[sourceHandle]?.type; - const targetType = targetNode.data.inputs[targetHandle]?.type; + const sourceField = sourceNode.data.outputs[sourceHandle]; + const targetField = targetNode.data.inputs[targetHandle]; - if (!sourceType || !targetType) { + if (!sourceField || !targetField) { // something has gone terribly awry return false; } @@ -70,12 +70,13 @@ export const useIsValidConnection = () => { return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType !== 'CollectionItem' + targetField.type.name !== 'CollectionItemField' ) { return false; } - if (!validateSourceAndTargetTypes(sourceType, targetType)) { + // Must use the originalType here if it exists + if (!validateSourceAndTargetTypes(sourceField.type, targetField.type)) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx index 890fa7a72d..3646e8dc58 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx +++ b/invokeai/frontend/web/src/features/nodes/hooks/useLoadWorkflowFromFile.tsx @@ -1,17 +1,15 @@ import { ListItem, Text, UnorderedList } from '@chakra-ui/react'; import { useLogger } from 'app/logging/useLogger'; import { useAppDispatch } from 'app/store/storeHooks'; -import { parseify } from 'common/util/serialize'; -import { zWorkflow } from 'features/nodes/types/types'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; -import { memo, useCallback } from 'react'; -import { ZodError } from 'zod'; -import { fromZodError, fromZodIssue } from 'zod-validation-error'; -import { workflowLoadRequested } from '../store/actions'; +import { RefObject, memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; +import { ZodError } from 'zod'; +import { fromZodIssue } from 'zod-validation-error'; +import { workflowLoadRequested } from '../store/actions'; -export const useLoadWorkflowFromFile = () => { +export const useLoadWorkflowFromFile = (resetRef: RefObject<() => void>) => { const dispatch = useAppDispatch(); const logger = useLogger('nodes'); const { t } = useTranslation(); @@ -26,33 +24,10 @@ export const useLoadWorkflowFromFile = () => { try { const parsedJSON = JSON.parse(String(rawJSON)); - const result = zWorkflow.safeParse(parsedJSON); - - if (!result.success) { - const { message } = fromZodError(result.error, { - prefix: t('nodes.workflowValidation'), - }); - - logger.error({ error: parseify(result.error) }, message); - - dispatch( - addToast( - makeToast({ - title: t('nodes.unableToValidateWorkflow'), - status: 'error', - duration: 5000, - }) - ) - ); - reader.abort(); - return; - } - - dispatch(workflowLoadRequested(result.data)); - - reader.abort(); - } catch { - // file reader error + dispatch(workflowLoadRequested(parsedJSON)); + } catch (e) { + // There was a problem reading the file + logger.error(t('nodes.unableToLoadWorkflow')); dispatch( addToast( makeToast({ @@ -61,12 +36,15 @@ export const useLoadWorkflowFromFile = () => { }) ) ); + reader.abort(); } }; reader.readAsText(file); + // Reset the file picker internal state so that the same file can be loaded again + resetRef.current?.(); }, - [dispatch, logger, t] + [dispatch, logger, resetRef, t] ); return loadWorkflowFromFile; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts index f9bbe4cc1d..edce18b52b 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeLabel.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useNodeLabel = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts new file mode 100644 index 0000000000..99a7c47170 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -0,0 +1,35 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; + +export const useNodeNeedsUpdate = (nodeId: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + const template = nodes.nodeTemplates[node?.data.type ?? '']; + return { node, template }; + }, + defaultSelectorOptions + ), + [nodeId] + ); + + const { node, template } = useAppSelector(selector); + + const needsUpdate = useMemo( + () => + isInvocationNode(node) && template + ? getNeedsUpdate(node, template) + : false, + [node, template] + ); + + return needsUpdate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index 6fd0615563..83012d0830 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -3,16 +3,14 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { AnyInvocationType } from 'services/events/types'; +import { InvocationTemplate } from '../types/invocation'; -export const useNodeTemplateByType = ( - type: AnyInvocationType | 'current_image' | 'notes' -) => { +export const useNodeTemplateByType = (type: string) => { const selector = useMemo( () => createSelector( stateSelector, - ({ nodes }) => { + ({ nodes }): InvocationTemplate | undefined => { const nodeTemplate = nodes.nodeTemplates[type]; return nodeTemplate; }, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 4ef3eed5d9..c3dc150735 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useNodeTemplateTitle = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts deleted file mode 100644 index 1f213d6481..0000000000 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeVersion.ts +++ /dev/null @@ -1,119 +0,0 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { stateSelector } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { satisfies } from 'compare-versions'; -import { cloneDeep, defaultsDeep } from 'lodash-es'; -import { useCallback, useMemo } from 'react'; -import { Node } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; -import { nodeReplaced } from '../store/nodesSlice'; -import { buildNodeData } from '../store/util/buildNodeData'; -import { - InvocationNodeData, - InvocationTemplate, - NodeData, - isInvocationNode, - zParsedSemver, -} from '../types/types'; -import { useAppToaster } from 'app/components/Toaster'; -import { useTranslation } from 'react-i18next'; - -export const getNeedsUpdate = ( - node?: Node, - template?: InvocationTemplate -) => { - if (!isInvocationNode(node) || !template) { - return false; - } - return node.data.version !== template.version; -}; - -export const getMayUpdateNode = ( - node?: Node, - template?: InvocationTemplate -) => { - const needsUpdate = getNeedsUpdate(node, template); - if ( - !needsUpdate || - !isInvocationNode(node) || - !template || - !node.data.version - ) { - return false; - } - const templateMajor = zParsedSemver.parse(template.version).major; - - return satisfies(node.data.version, `^${templateMajor}`); -}; - -export const updateNode = ( - node?: Node, - template?: InvocationTemplate -) => { - const mayUpdate = getMayUpdateNode(node, template); - if ( - !mayUpdate || - !isInvocationNode(node) || - !template || - !node.data.version - ) { - return; - } - - const defaults = buildNodeData( - node.data.type as AnyInvocationType, - node.position, - template - ) as Node; - - const clone = cloneDeep(node); - clone.data.version = template.version; - defaultsDeep(clone, defaults); - return clone; -}; - -export const useNodeVersion = (nodeId: string) => { - const dispatch = useAppDispatch(); - const toast = useAppToaster(); - const { t } = useTranslation(); - const selector = useMemo( - () => - createSelector( - stateSelector, - ({ nodes }) => { - const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; - return { node, nodeTemplate }; - }, - defaultSelectorOptions - ), - [nodeId] - ); - - const { node, nodeTemplate } = useAppSelector(selector); - - const needsUpdate = useMemo( - () => getNeedsUpdate(node, nodeTemplate), - [node, nodeTemplate] - ); - - const mayUpdate = useMemo( - () => getMayUpdateNode(node, nodeTemplate), - [node, nodeTemplate] - ); - - const _updateNode = useCallback(() => { - const needsUpdate = getNeedsUpdate(node, nodeTemplate); - const updatedNode = updateNode(node, nodeTemplate); - if (!updatedNode) { - if (needsUpdate) { - toast({ title: t('nodes.unableToUpdateNodes', { count: 1 }) }); - } - return; - } - dispatch(nodeReplaced({ nodeId: updatedNode.id, node: updatedNode })); - }, [dispatch, node, nodeTemplate, t, toast]); - - return { needsUpdate, mayUpdate, updateNode: _updateNode }; -}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index e0a1e5433e..93e4ccb833 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; export const useOutputFieldNames = (nodeId: string) => { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts new file mode 100644 index 0000000000..bff5873864 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/usePrettyFieldType.ts @@ -0,0 +1,23 @@ +import { useTranslation } from 'react-i18next'; +import { FieldType } from '../types/field'; +import { useMemo } from 'react'; + +export const useFieldTypeName = (fieldType?: FieldType): string => { + const { t } = useTranslation(); + + const name = useMemo(() => { + if (!fieldType) { + return ''; + } + const { name } = fieldType; + if (fieldType.isCollection) { + return t('nodes.collectionFieldType', { name }); + } + if (fieldType.isPolymorphic) { + return t('nodes.polymorphicFieldType', { name }); + } + return name; + }, [fieldType, t]); + + return name; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts index 7416d7e66e..e05cbd818f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useUseCache.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useUseCache = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts index 3c83e01731..c495c54974 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWithWorkflow.ts @@ -3,7 +3,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; -import { isInvocationNode } from '../types/types'; +import { isInvocationNode } from '../types/invocation'; export const useWithWorkflow = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/store/actions.ts b/invokeai/frontend/web/src/features/nodes/store/actions.ts index 0d75e6934d..5dd5344a99 100644 --- a/invokeai/frontend/web/src/features/nodes/store/actions.ts +++ b/invokeai/frontend/web/src/features/nodes/store/actions.ts @@ -1,6 +1,5 @@ import { createAction, isAnyOf } from '@reduxjs/toolkit'; import { Graph } from 'services/api/types'; -import { Workflow } from '../types/types'; export const textToImageGraphBuilt = createAction( 'nodes/textToImageGraphBuilt' @@ -18,7 +17,7 @@ export const isAnyGraphBuilt = isAnyOf( nodesGraphBuilt ); -export const workflowLoadRequested = createAction( +export const workflowLoadRequested = createAction( 'nodes/workflowLoadRequested' ); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts index 64fee2293f..1322bafa43 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts @@ -6,7 +6,7 @@ import { NodesState } from './types'; export const nodesPersistDenylist: (keyof NodesState)[] = [ 'nodeTemplates', 'connectionStartParams', - 'currentConnectionFieldType', + 'connectionStartFieldType', 'selectedNodes', 'selectedEdges', 'isReady', diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 3acef5978f..0c21d02fed 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,7 +20,6 @@ import { XYPosition, } from 'reactflow'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; -import { ImageField } from 'services/api/types'; import { appSocketGeneratorProgress, appSocketInvocationComplete, @@ -31,60 +30,58 @@ import { import { v4 as uuidv4 } from 'uuid'; import { DRAG_HANDLE_CLASSNAME } from '../types/constants'; import { - BoardInputFieldValue, - BooleanInputFieldValue, - ColorInputFieldValue, - ControlNetModelInputFieldValue, - CurrentImageNodeData, - EnumInputFieldValue, + BoardFieldValue, + BooleanFieldValue, + ColorFieldValue, + ControlNetModelFieldValue, + EnumFieldValue, FieldIdentifier, - FloatInputFieldValue, - ImageInputFieldValue, - InputFieldValue, - IntegerInputFieldValue, - InvocationNodeData, + FieldValue, + FloatFieldValue, + ImageFieldValue, + IntegerFieldValue, + IPAdapterModelFieldValue, + LoRAModelFieldValue, + MainModelFieldValue, + SchedulerFieldValue, + SDXLRefinerModelFieldValue, + StringFieldValue, + T2IAdapterModelFieldValue, + VAEModelFieldValue, +} from '../types/field'; +import { + AnyNodeData, InvocationTemplate, - IPAdapterModelInputFieldValue, isInvocationNode, isNotesNode, - LoRAModelInputFieldValue, - MainModelInputFieldValue, NodeExecutionState, - NodeStatus, - NotesNodeData, - SchedulerInputFieldValue, - SDXLRefinerModelInputFieldValue, - StringInputFieldValue, - T2IAdapterModelInputFieldValue, - VaeModelInputFieldValue, - Workflow, -} from '../types/types'; + zNodeStatus, +} from '../types/invocation'; +import { WorkflowV2 } from '../types/workflow'; import { NodesState } from './types'; -import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; - -export const WORKFLOW_FORMAT_VERSION = '1.0.0'; +import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; const initialNodeExecutionState: Omit = { - status: NodeStatus.PENDING, + status: zNodeStatus.enum.PENDING, error: null, progress: null, progressImage: null, outputs: [], }; -export const initialWorkflow = { - meta: { - version: WORKFLOW_FORMAT_VERSION, - }, +const INITIAL_WORKFLOW: WorkflowV2 = { name: '', author: '', description: '', - notes: '', - tags: '', - contact: '', version: '', + contact: '', + tags: '', + notes: '', + nodes: [], + edges: [], exposedFields: [], + meta: { version: '2.0.0' }, }; export const initialNodesState: NodesState = { @@ -93,11 +90,10 @@ export const initialNodesState: NodesState = { nodeTemplates: {}, isReady: false, connectionStartParams: null, - currentConnectionFieldType: null, + connectionStartFieldType: null, connectionMade: false, modifyingEdge: false, addNewNodePosition: null, - shouldShowFieldTypeLegend: false, shouldShowMinimapPanel: true, shouldValidateGraph: true, shouldAnimateEdges: true, @@ -107,7 +103,7 @@ export const initialNodesState: NodesState = { nodeOpacity: 1, selectedNodes: [], selectedEdges: [], - workflow: initialWorkflow, + workflow: INITIAL_WORKFLOW, nodeExecutionStates: {}, viewport: { x: 0, y: 0, zoom: 1 }, mouseOverField: null, @@ -117,13 +113,13 @@ export const initialNodesState: NodesState = { selectionMode: SelectionMode.Partial, }; -type FieldValueAction = PayloadAction<{ +type FieldValueAction = PayloadAction<{ nodeId: string; fieldName: string; - value: T['value']; + value: T; }>; -const fieldValueReducer = ( +const fieldValueReducer = ( state: NodesState, action: FieldValueAction ) => { @@ -161,12 +157,7 @@ const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: ( - state, - action: PayloadAction< - Node - > - ) => { + nodeAdded: (state, action: PayloadAction>) => { const node = action.payload; const position = findUnoccupiedPosition( state.nodes, @@ -203,7 +194,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( node, @@ -212,7 +203,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -224,7 +215,7 @@ const nodesSlice = createSlice({ } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, edgeChangeStarted: (state) => { state.modifyingEdge = true; @@ -258,10 +249,10 @@ const nodesSlice = createSlice({ handleType === 'source' ? node.data.outputs[handleId] : node.data.inputs[handleId]; - state.currentConnectionFieldType = field?.type ?? null; + state.connectionStartFieldType = field?.type ?? null; }, connectionMade: (state, action: PayloadAction) => { - const fieldType = state.currentConnectionFieldType; + const fieldType = state.connectionStartFieldType; if (!fieldType) { return; } @@ -286,7 +277,7 @@ const nodesSlice = createSlice({ nodeId && handleId && handleType && - state.currentConnectionFieldType + state.connectionStartFieldType ) { const newConnection = findConnectionToValidHandle( mouseOverNode, @@ -295,7 +286,7 @@ const nodesSlice = createSlice({ nodeId, handleId, handleType, - state.currentConnectionFieldType + state.connectionStartFieldType ); if (newConnection) { state.edges = addEdge( @@ -306,14 +297,14 @@ const nodesSlice = createSlice({ } } state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } else { state.addNewNodePosition = action.payload.cursorPosition; state.isAddNodePopoverOpen = true; } } else { state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; } state.modifyingEdge = false; }, @@ -529,12 +520,7 @@ const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); } }, - nodesDeleted: ( - state, - action: PayloadAction< - Node[] - > - ) => { + nodesDeleted: (state, action: PayloadAction[]>) => { action.payload.forEach((node) => { state.workflow.exposedFields = state.workflow.exposedFields.filter( (f) => f.nodeId !== node.id @@ -588,132 +574,94 @@ const nodesSlice = createSlice({ }, fieldStringValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldNumberValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBooleanValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldBoardValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldImageValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldColorValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldMainModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldRefinerModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldVaeModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldLoRAModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldControlNetModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldIPAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldT2IAdapterModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldEnumModelValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, fieldSchedulerValueChanged: ( state, - action: FieldValueAction + action: FieldValueAction ) => { fieldValueReducer(state, action); }, - imageCollectionFieldValueChanged: ( - state, - action: PayloadAction<{ - nodeId: string; - fieldName: string; - value: ImageField[]; - }> - ) => { - const { nodeId, fieldName, value } = action.payload; - const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); - - if (nodeIndex === -1) { - return; - } - - const node = state.nodes?.[nodeIndex]; - - if (!isInvocationNode(node)) { - return; - } - - const input = node.data?.inputs[fieldName]; - if (!input) { - return; - } - - const currentValue = cloneDeep(input.value); - - if (!currentValue) { - input.value = value; - return; - } - - input.value = uniqBy( - (currentValue as ImageField[]).concat(value), - 'image_name' - ); - }, notesNodeValueChanged: ( state, action: PayloadAction<{ nodeId: string; value: string }> @@ -726,12 +674,6 @@ const nodesSlice = createSlice({ } node.data.notes = value; }, - shouldShowFieldTypeLegendChanged: ( - state, - action: PayloadAction - ) => { - state.shouldShowFieldTypeLegend = action.payload; - }, shouldShowMinimapPanelChanged: (state, action: PayloadAction) => { state.shouldShowMinimapPanel = action.payload; }, @@ -745,7 +687,7 @@ const nodesSlice = createSlice({ nodeEditorReset: (state) => { state.nodes = []; state.edges = []; - state.workflow = cloneDeep(initialWorkflow); + state.workflow = cloneDeep(INITIAL_WORKFLOW); }, shouldValidateGraphChanged: (state, action: PayloadAction) => { state.shouldValidateGraph = action.payload; @@ -783,7 +725,7 @@ const nodesSlice = createSlice({ workflowContactChanged: (state, action: PayloadAction) => { state.workflow.contact = action.payload; }, - workflowLoaded: (state, action: PayloadAction) => { + workflowLoaded: (state, action: PayloadAction) => { const { nodes, edges, ...workflow } = action.payload; state.workflow = workflow; @@ -810,7 +752,7 @@ const nodesSlice = createSlice({ }, {}); }, workflowReset: (state) => { - state.workflow = cloneDeep(initialWorkflow); + state.workflow = cloneDeep(INITIAL_WORKFLOW); }, viewportChanged: (state, action: PayloadAction) => { state.viewport = action.payload; @@ -942,7 +884,7 @@ const nodesSlice = createSlice({ //Make sure these get reset if we close the popover and haven't selected a node state.connectionStartParams = null; - state.currentConnectionFieldType = null; + state.connectionStartFieldType = null; }, addNodePopoverToggled: (state) => { state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; @@ -961,14 +903,14 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.IN_PROGRESS; + node.status = zNodeStatus.enum.IN_PROGRESS; } }); builder.addCase(appSocketInvocationComplete, (state, action) => { const { source_node_id, result } = action.payload.data; const nes = state.nodeExecutionStates[source_node_id]; if (nes) { - nes.status = NodeStatus.COMPLETED; + nes.status = zNodeStatus.enum.COMPLETED; if (nes.progress !== null) { nes.progress = 1; } @@ -979,7 +921,7 @@ const nodesSlice = createSlice({ const { source_node_id } = action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.FAILED; + node.status = zNodeStatus.enum.FAILED; node.error = action.payload.data.error; node.progress = null; node.progressImage = null; @@ -990,7 +932,7 @@ const nodesSlice = createSlice({ action.payload.data; const node = state.nodeExecutionStates[source_node_id]; if (node) { - node.status = NodeStatus.IN_PROGRESS; + node.status = zNodeStatus.enum.IN_PROGRESS; node.progress = (step + 1) / total_steps; node.progressImage = progress_image ?? null; } @@ -998,7 +940,7 @@ const nodesSlice = createSlice({ builder.addCase(appSocketQueueItemStatusChanged, (state, action) => { if (['in_progress'].includes(action.payload.data.queue_item.status)) { forEach(state.nodeExecutionStates, (nes) => { - nes.status = NodeStatus.PENDING; + nes.status = zNodeStatus.enum.PENDING; nes.error = null; nes.progress = null; nes.progressImage = null; @@ -1037,7 +979,6 @@ export const { fieldSchedulerValueChanged, fieldStringValueChanged, fieldVaeModelValueChanged, - imageCollectionFieldValueChanged, mouseOverFieldChanged, mouseOverNodeChanged, nodeAdded, @@ -1063,7 +1004,6 @@ export const { selectionPasted, shouldAnimateEdgesChanged, shouldColorEdgesChanged, - shouldShowFieldTypeLegendChanged, shouldShowMinimapPanelChanged, shouldSnapToGridChanged, shouldValidateGraphChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index f6bfa7cad8..b865b9d3a1 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -6,25 +6,23 @@ import { Viewport, XYPosition, } from 'reactflow'; +import { FieldIdentifier, FieldType } from '../types/field'; import { - FieldIdentifier, - FieldType, + AnyNodeData, InvocationEdgeExtra, InvocationTemplate, - NodeData, NodeExecutionState, - Workflow, -} from '../types/types'; +} from '../types/invocation'; +import { WorkflowV2 } from '../types/workflow'; export type NodesState = { - nodes: Node[]; + nodes: Node[]; edges: Edge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; - currentConnectionFieldType: FieldType | null; + connectionStartFieldType: FieldType | null; connectionMade: boolean; modifyingEdge: boolean; - shouldShowFieldTypeLegend: boolean; shouldShowMinimapPanel: boolean; shouldValidateGraph: boolean; shouldAnimateEdges: boolean; @@ -33,13 +31,13 @@ export type NodesState = { shouldColorEdges: boolean; selectedNodes: string[]; selectedEdges: string[]; - workflow: Omit; + workflow: Omit; nodeExecutionStates: Record; viewport: Viewport; isReady: boolean; mouseOverField: FieldIdentifier | null; mouseOverNode: string | null; - nodesToCopy: Node[]; + nodesToCopy: Node[]; edgesToCopy: Edge[]; isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts index 6cecc8c409..5328f789ad 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -1,78 +1,73 @@ import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + FieldInputInstance, + FieldOutputInstance, +} from 'features/nodes/types/field'; import { CurrentImageNodeData, - InputFieldValue, InvocationNodeData, InvocationTemplate, NotesNodeData, - OutputFieldValue, -} from 'features/nodes/types/types'; -import { buildInputFieldValue } from 'features/nodes/util/fieldValueBuilders'; +} from 'features/nodes/types/invocation'; +import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance'; import { reduce } from 'lodash-es'; import { Node, XYPosition } from 'reactflow'; -import { AnyInvocationType } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; export const SHARED_NODE_PROPERTIES: Partial = { dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, }; -export const buildNodeData = ( - type: AnyInvocationType | 'current_image' | 'notes', - position: XYPosition, - template?: InvocationTemplate -): - | Node - | Node - | Node - | undefined => { - const nodeId = uuidv4(); - if (type === 'current_image') { - const node: Node = { - ...SHARED_NODE_PROPERTIES, +export const buildNotesNode = (position: XYPosition): Node => { + const nodeId = uuidv4(); + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'notes', + position, + data: { + id: nodeId, + isOpen: true, + label: 'Notes', + notes: '', + type: 'notes', + }, + }; + return node; +}; + +export const buildCurrentImageNode = ( + position: XYPosition +): Node => { + const nodeId = uuidv4(); + const node: Node = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'current_image', + position, + data: { id: nodeId, type: 'current_image', - position, - data: { - id: nodeId, - type: 'current_image', - isOpen: true, - label: 'Current Image', - }, - }; + isOpen: true, + label: 'Current Image', + }, + }; + return node; +}; - return node; - } - - if (type === 'notes') { - const node: Node = { - ...SHARED_NODE_PROPERTIES, - id: nodeId, - type: 'notes', - position, - data: { - id: nodeId, - isOpen: true, - label: 'Notes', - notes: '', - type: 'notes', - }, - }; - - return node; - } - - if (template === undefined) { - console.error(`Unable to find template ${type}.`); - return; - } +export const buildInvocationNode = ( + position: XYPosition, + template: InvocationTemplate +): Node => { + const nodeId = uuidv4(); + const { type } = template; const inputs = reduce( template.inputs, (inputsAccumulator, inputTemplate, inputName) => { const fieldId = uuidv4(); - const inputFieldValue: InputFieldValue = buildInputFieldValue( + const inputFieldValue: FieldInputInstance = buildFieldInputInstance( fieldId, inputTemplate ); @@ -81,7 +76,7 @@ export const buildNodeData = ( return inputsAccumulator; }, - {} as Record + {} as Record ); const outputs = reduce( @@ -89,7 +84,7 @@ export const buildNodeData = ( (outputsAccumulator, outputTemplate, outputName) => { const fieldId = uuidv4(); - const outputFieldValue: OutputFieldValue = { + const outputFieldValue: FieldOutputInstance = { id: fieldId, name: outputName, type: outputTemplate.type, @@ -100,10 +95,10 @@ export const buildNodeData = ( return outputsAccumulator; }, - {} as Record + {} as Record ); - const invocation: Node = { + const node: Node = { ...SHARED_NODE_PROPERTIES, id: nodeId, type: 'invocation', @@ -117,11 +112,11 @@ export const buildNodeData = ( isOpen: true, embedWorkflow: false, isIntermediate: type === 'save_image' ? false : true, + useCache: template.useCache, inputs, outputs, - useCache: template.useCache, }, }; - return invocation; + return node; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index 69386c1f23..0a7adf77cb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,20 +1,19 @@ -import { Connection, HandleType } from 'reactflow'; -import { Node, Edge } from 'reactflow'; -import { - FieldType, - InputFieldValue, - OutputFieldValue, -} from 'features/nodes/types/types'; +import { Connection, Edge, HandleType, Node } from 'reactflow'; -import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; +import { + FieldInputInstance, + FieldOutputInstance, + FieldType, +} from 'features/nodes/types/field'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; +import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; const isValidConnection = ( edges: Edge[], handleCurrentType: HandleType, handleCurrentFieldType: FieldType, node: Node, - handle: InputFieldValue | OutputFieldValue + handle: FieldInputInstance | FieldOutputInstance ) => { let isValidConnection = true; if (handleCurrentType === 'source') { diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 57dd284b88..de79561291 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,9 +1,9 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { getIsGraphAcyclic } from './getIsGraphAcyclic'; -import { FieldType } from 'features/nodes/types/types'; +import { FieldType } from 'features/nodes/types/field'; import i18n from 'i18next'; import { HandleType } from 'reactflow'; +import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; /** @@ -17,15 +17,15 @@ export const makeConnectionErrorSelector = ( handleType: HandleType, fieldType?: FieldType ) => { - return createSelector(stateSelector, (state) => { + return createSelector(stateSelector, (state): string | undefined => { if (!fieldType) { return i18n.t('nodes.noFieldType'); } - const { currentConnectionFieldType, connectionStartParams, nodes, edges } = + const { connectionStartFieldType, connectionStartParams, nodes, edges } = state.nodes; - if (!connectionStartParams || !currentConnectionFieldType) { + if (!connectionStartParams || !connectionStartFieldType) { return i18n.t('nodes.noConnectionInProgress'); } @@ -40,9 +40,9 @@ export const makeConnectionErrorSelector = ( } const targetType = - handleType === 'target' ? fieldType : currentConnectionFieldType; + handleType === 'target' ? fieldType : connectionStartFieldType; const sourceType = - handleType === 'source' ? fieldType : currentConnectionFieldType; + handleType === 'source' ? fieldType : connectionStartFieldType; if (nodeId === connectionNodeId) { return i18n.t('nodes.cannotConnectToSelf'); @@ -80,7 +80,7 @@ export const makeConnectionErrorSelector = ( return edge.target === target && edge.targetHandle === targetHandle; }) && // except CollectionItem inputs can have multiples - targetType !== 'CollectionItem' + targetType.name !== 'CollectionItemField' ) { return i18n.t('nodes.inputMayOnlyHaveOneConnection'); } @@ -100,6 +100,6 @@ export const makeConnectionErrorSelector = ( return i18n.t('nodes.connectionWouldCreateCycle'); } - return null; + return; }); }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts new file mode 100644 index 0000000000..e9e24823f9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts @@ -0,0 +1,68 @@ +import { satisfies } from 'compare-versions'; +import { NodeUpdateError } from 'features/nodes/types/error'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/invocation'; +import { zParsedSemver } from 'features/nodes/types/semver'; +import { cloneDeep, defaultsDeep } from 'lodash-es'; +import { Node } from 'reactflow'; +import { buildInvocationNode } from './buildNodeData'; + +export const getNeedsUpdate = ( + node: Node, + template: InvocationTemplate +): boolean => { + if (node.data.type !== template.type) { + return true; + } + return node.data.version !== template.version; +}; /** + * Checks if a node may be updated by comparing its major version with the template's major version. + * @param node The node to check. + * @param template The invocation template to check against. + */ + +export const getMayUpdateNode = ( + node: Node, + template: InvocationTemplate +): boolean => { + const needsUpdate = getNeedsUpdate(node, template); + if (!needsUpdate || node.data.type !== template.type) { + return false; + } + const templateMajor = zParsedSemver.parse(template.version).major; + + return satisfies(node.data.version, `^${templateMajor}`); +}; /** + * Updates a node to the latest version of its template: + * - Create a new node data object with the latest version of the template. + * - Recursively merge new node data object into the node to be updated. + * + * @param node The node to updated. + * @param template The invocation template to update to. + * @throws {NodeUpdateError} If the node is not an invocation node. + */ + +export const updateNode = ( + node: Node, + template: InvocationTemplate +): Node => { + const mayUpdate = getMayUpdateNode(node, template); + + if (!mayUpdate || node.data.type !== template.type) { + throw new NodeUpdateError(`Unable to update node ${node.id}`); + } + + // Start with a "fresh" node - just as if the user created a new node of this type + const defaults = buildInvocationNode(node.position, template); + + // The updateability of a node, via semver comparison, relies on the this kind of recursive merge + // being valid. We rely on the template's major version to be majorly incremented if this kind of + // merge would result in an invalid node. + const clone = cloneDeep(node); + clone.data.version = template.version; + defaultsDeep(clone, defaults); // mutates! + + return clone; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 2f47e47a78..2770af19e3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -1,11 +1,12 @@ -import { - COLLECTION_MAP, - COLLECTION_TYPES, - POLYMORPHIC_TO_SINGLE_MAP, - POLYMORPHIC_TYPES, -} from 'features/nodes/types/constants'; -import { FieldType } from 'features/nodes/types/types'; +import { FieldType } from 'features/nodes/types/field'; +import { isEqual } from 'lodash-es'; +/** + * Validates that the source and target types are compatible for a connection. + * @param sourceType The type of the source field. + * @param targetType The type of the target field. + * @returns True if the connection is valid, false otherwise. + */ export const validateSourceAndTargetTypes = ( sourceType: FieldType, targetType: FieldType @@ -13,11 +14,14 @@ export const validateSourceAndTargetTypes = ( // TODO: There's a bug with Collect -> Iterate nodes: // https://github.com/invoke-ai/InvokeAI/issues/3956 // Once this is resolved, we can remove this check. - if (sourceType === 'Collection' && targetType === 'Collection') { + if ( + sourceType.name === 'CollectionField' && + targetType.name === 'CollectionField' + ) { return false; } - if (sourceType === targetType) { + if (isEqual(sourceType, targetType)) { return true; } @@ -31,46 +35,42 @@ export const validateSourceAndTargetTypes = ( */ const isCollectionItemToNonCollection = - sourceType === 'CollectionItem' && !COLLECTION_TYPES.includes(targetType); + sourceType.name === 'CollectionItemField' && !targetType.isCollection; const isNonCollectionToCollectionItem = - targetType === 'CollectionItem' && - !COLLECTION_TYPES.includes(sourceType) && - !POLYMORPHIC_TYPES.includes(sourceType); + targetType.name === 'CollectionItemField' && + !sourceType.isCollection && + !sourceType.isPolymorphic; const isAnythingToPolymorphicOfSameBaseType = - POLYMORPHIC_TYPES.includes(targetType) && - (() => { - if (!POLYMORPHIC_TYPES.includes(targetType)) { - return false; - } - const baseType = - POLYMORPHIC_TO_SINGLE_MAP[ - targetType as keyof typeof POLYMORPHIC_TO_SINGLE_MAP - ]; - - const collectionType = - COLLECTION_MAP[baseType as keyof typeof COLLECTION_MAP]; - - return sourceType === baseType || sourceType === collectionType; - })(); + targetType.isPolymorphic && sourceType.name === targetType.name; const isGenericCollectionToAnyCollectionOrPolymorphic = - sourceType === 'Collection' && - (COLLECTION_TYPES.includes(targetType) || - POLYMORPHIC_TYPES.includes(targetType)); + sourceType.name === 'CollectionField' && + (targetType.isCollection || targetType.isPolymorphic); const isCollectionToGenericCollection = - targetType === 'Collection' && COLLECTION_TYPES.includes(sourceType); + targetType.name === 'CollectionField' && sourceType.isCollection; - const isIntToFloat = sourceType === 'integer' && targetType === 'float'; + const areBothTypesSingle = + !sourceType.isCollection && + !sourceType.isPolymorphic && + !targetType.isCollection && + !targetType.isPolymorphic; + + const isIntToFloat = + areBothTypesSingle && + sourceType.name === 'IntegerField' && + targetType.name === 'FloatField'; const isIntOrFloatToString = - (sourceType === 'integer' || sourceType === 'float') && - targetType === 'string'; + areBothTypesSingle && + (sourceType.name === 'IntegerField' || sourceType.name === 'FloatField') && + targetType.name === 'StringField'; - const isTargetAnyType = targetType === 'Any'; + const isTargetAnyType = targetType.name === 'AnyField'; + // One of these must be true for the connection to be valid return ( isCollectionItemToNonCollection || isNonCollectionToCollectionItem || diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts new file mode 100644 index 0000000000..0cab248c80 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -0,0 +1,216 @@ +import { z } from 'zod'; + +// #region Field data schemas +export const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); +export type ImageField = z.infer; + +export const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); +export type BoardField = z.infer; + +export const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); +export type ColorField = z.infer; + +export const zSchedulerField = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +export type SchedulerField = z.infer; +// #endregion + +// #region Model-related schemas +export const zBaseModel = z.enum([ + 'any', + 'sd-1', + 'sd-2', + 'sdxl', + 'sdxl-refiner', +]); +export const zModelType = z.enum([ + 'onnx', + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', +]); +export const zModelName = z.string().trim().min(1); +export const zModelIdentifier = z.object({ + model_name: zModelName, + base_model: zBaseModel, +}); +export type BaseModel = z.infer; +export type ModelType = z.infer; +export type ModelIdentifier = z.infer; + +export const zMainModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('main'), +}); +export const zONNXModelField = z.object({ + model_name: zModelName, + base_model: zBaseModel, + model_type: z.literal('onnx'), +}); +export const zMainOrONNXModelField = z.union([ + zMainModelField, + zONNXModelField, +]); +export const zSDXLRefinerModelField = z.object({ + model_name: z.string().min(1), + base_model: z.literal('sdxl-refiner'), + model_type: z.literal('main'), +}); +export type MainModelField = z.infer; +export type ONNXModelField = z.infer; +export type MainOrONNXModelField = z.infer; +export type SDXLRefinerModelField = z.infer; + +export const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); +export type SubModelType = z.infer; + +export const zVAEModelField = zModelIdentifier; + +export const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); +export type ModelInfo = z.infer; + +export const zLoRAModelField = zModelIdentifier; +export type LoRAModelField = z.infer; + +export const zControlNetModelField = zModelIdentifier; +export type ControlNetModelField = z.infer; + +export const zIPAdapterModelField = zModelIdentifier; +export type IPAdapterModelField = z.infer; + +export const zT2IAdapterModelField = zModelIdentifier; +export type T2IAdapterModelField = z.infer; + +export const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); +export type LoraInfo = z.infer; + +export const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); +export type UNetField = z.infer; + +export const zCLIPField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); +export type CLIPField = z.infer; + +export const zVAEField = z.object({ + vae: zModelInfo, +}); +export type VAEField = z.infer; +// #endregion + +// #region Control Adapters +export const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModelField, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z + .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) + .optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); +export type ControlField = z.infer; + +export const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModelField, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); +export type IPAdapterField = z.infer; + +export const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModelField, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); +export type T2IAdapterField = z.infer; +// #endregion + +// #region ProgressImage +export const zProgressImage = z.object({ + dataURL: z.string(), + width: z.number().int(), + height: z.number().int(), +}); +export type ProgressImage = z.infer; +// #endregion + +// #region ImageOutput +export const zImageOutput = z.object({ + image: zImageField, + width: z.number().int(), + height: z.number().int(), + type: z.literal('image_output'), +}); +export type ImageOutput = z.infer; +export const isImageOutput = (output: unknown): output is ImageOutput => + zImageOutput.safeParse(output).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index c6eec736da..a97899de91 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -1,58 +1,31 @@ -import { - FieldType, - FieldTypeMap, - FieldTypeMapWithNumber, - FieldUIConfig, -} from './types'; -import { t } from 'i18next'; - +/** + * How long to wait before showing a tooltip when hovering a field handle. + */ export const HANDLE_TOOLTIP_OPEN_DELAY = 500; -export const COLOR_TOKEN_VALUE = 500; + +/** + * The width of a node in the UI in pixels. + */ export const NODE_WIDTH = 320; -export const NODE_MIN_WIDTH = 320; + +/** + * This class name is special - reactflow uses it to identify the drag handle of a node, + * applying the appropriate listeners to it. + */ export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; -export const IMAGE_FIELDS = ['ImageField', 'ImageCollection']; -export const FOOTER_FIELDS = IMAGE_FIELDS; - +/** + * Helper for getting the kind of a field. + */ export const KIND_MAP = { input: 'inputs' as const, output: 'outputs' as const, }; -export const COLLECTION_TYPES: FieldType[] = [ - 'Collection', - 'IntegerCollection', - 'BooleanCollection', - 'FloatCollection', - 'StringCollection', - 'ImageCollection', - 'LatentsCollection', - 'ConditioningCollection', - 'ControlCollection', - 'ColorCollection', - 'T2IAdapterCollection', - 'IPAdapterCollection', - 'MetadataItemCollection', - 'MetadataCollection', -]; - -export const POLYMORPHIC_TYPES: FieldType[] = [ - 'IntegerPolymorphic', - 'BooleanPolymorphic', - 'FloatPolymorphic', - 'StringPolymorphic', - 'ImagePolymorphic', - 'LatentsPolymorphic', - 'ConditioningPolymorphic', - 'ControlPolymorphic', - 'ColorPolymorphic', - 'T2IAdapterPolymorphic', - 'IPAdapterPolymorphic', - 'MetadataItemPolymorphic', -]; - -export const MODEL_TYPES: FieldType[] = [ +/** + * Model types' handles are rendered as squares in the UI. + */ +export const MODEL_TYPES = [ 'IPAdapterModelField', 'ControlNetModelField', 'LoRAModelField', @@ -68,373 +41,33 @@ export const MODEL_TYPES: FieldType[] = [ 'IPAdapterModelField', ]; -export const COLLECTION_MAP: FieldTypeMapWithNumber = { - integer: 'IntegerCollection', - boolean: 'BooleanCollection', - number: 'FloatCollection', - float: 'FloatCollection', - string: 'StringCollection', - ImageField: 'ImageCollection', - LatentsField: 'LatentsCollection', - ConditioningField: 'ConditioningCollection', - ControlField: 'ControlCollection', - ColorField: 'ColorCollection', - T2IAdapterField: 'T2IAdapterCollection', - IPAdapterField: 'IPAdapterCollection', - MetadataItemField: 'MetadataItemCollection', - MetadataField: 'MetadataCollection', -}; -export const isCollectionItemType = ( - itemType: string | undefined -): itemType is keyof typeof COLLECTION_MAP => - Boolean(itemType && itemType in COLLECTION_MAP); - -export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = { - integer: 'IntegerPolymorphic', - boolean: 'BooleanPolymorphic', - number: 'FloatPolymorphic', - float: 'FloatPolymorphic', - string: 'StringPolymorphic', - ImageField: 'ImagePolymorphic', - LatentsField: 'LatentsPolymorphic', - ConditioningField: 'ConditioningPolymorphic', - ControlField: 'ControlPolymorphic', - ColorField: 'ColorPolymorphic', - T2IAdapterField: 'T2IAdapterPolymorphic', - IPAdapterField: 'IPAdapterPolymorphic', - MetadataItemField: 'MetadataItemPolymorphic', -}; - -export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = { - IntegerPolymorphic: 'integer', - BooleanPolymorphic: 'boolean', - FloatPolymorphic: 'float', - StringPolymorphic: 'string', - ImagePolymorphic: 'ImageField', - LatentsPolymorphic: 'LatentsField', - ConditioningPolymorphic: 'ConditioningField', - ControlPolymorphic: 'ControlField', - ColorPolymorphic: 'ColorField', - T2IAdapterPolymorphic: 'T2IAdapterField', - IPAdapterPolymorphic: 'IPAdapterField', - MetadataItemPolymorphic: 'MetadataItemField', -}; - -export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [ - 'string', - 'StringPolymorphic', - 'boolean', - 'BooleanPolymorphic', - 'integer', - 'float', - 'FloatPolymorphic', - 'IntegerPolymorphic', - 'enum', - 'ImageField', - 'ImagePolymorphic', - 'MainModelField', - 'SDXLRefinerModelField', - 'VaeModelField', - 'LoRAModelField', - 'ControlNetModelField', - 'ColorField', - 'SDXLMainModelField', - 'Scheduler', - 'IPAdapterModelField', - 'BoardField', - 'T2IAdapterModelField', -]; - -export const isPolymorphicItemType = ( - itemType: string | undefined -): itemType is keyof typeof SINGLE_TO_POLYMORPHIC_MAP => - Boolean(itemType && itemType in SINGLE_TO_POLYMORPHIC_MAP); - -export const FIELDS: Record = { - Any: { - color: 'gray.500', - description: 'Any field type is accepted.', - title: 'Any', - }, - MetadataField: { - color: 'gray.500', - description: 'A metadata dict.', - title: 'Metadata Dict', - }, - MetadataCollection: { - color: 'gray.500', - description: 'A collection of metadata dicts.', - title: 'Metadata Dict Collection', - }, - MetadataItemField: { - color: 'gray.500', - description: 'A metadata item.', - title: 'Metadata Item', - }, - MetadataItemCollection: { - color: 'gray.500', - description: 'Any field type is accepted.', - title: 'Metadata Item Collection', - }, - MetadataItemPolymorphic: { - color: 'gray.500', - description: - 'MetadataItem or MetadataItemCollection field types are accepted.', - title: 'Metadata Item Polymorphic', - }, - boolean: { - color: 'green.500', - description: t('nodes.booleanDescription'), - title: t('nodes.boolean'), - }, - BooleanCollection: { - color: 'green.500', - description: t('nodes.booleanCollectionDescription'), - title: t('nodes.booleanCollection'), - }, - BooleanPolymorphic: { - color: 'green.500', - description: t('nodes.booleanPolymorphicDescription'), - title: t('nodes.booleanPolymorphic'), - }, - ClipField: { - color: 'green.500', - description: t('nodes.clipFieldDescription'), - title: t('nodes.clipField'), - }, - Collection: { - color: 'base.500', - description: t('nodes.collectionDescription'), - title: t('nodes.collection'), - }, - CollectionItem: { - color: 'base.500', - description: t('nodes.collectionItemDescription'), - title: t('nodes.collectionItem'), - }, - ColorCollection: { - color: 'pink.300', - description: t('nodes.colorCollectionDescription'), - title: t('nodes.colorCollection'), - }, - ColorField: { - color: 'pink.300', - description: t('nodes.colorFieldDescription'), - title: t('nodes.colorField'), - }, - ColorPolymorphic: { - color: 'pink.300', - description: t('nodes.colorPolymorphicDescription'), - title: t('nodes.colorPolymorphic'), - }, - ConditioningCollection: { - color: 'cyan.500', - description: t('nodes.conditioningCollectionDescription'), - title: t('nodes.conditioningCollection'), - }, - ConditioningField: { - color: 'cyan.500', - description: t('nodes.conditioningFieldDescription'), - title: t('nodes.conditioningField'), - }, - ConditioningPolymorphic: { - color: 'cyan.500', - description: t('nodes.conditioningPolymorphicDescription'), - title: t('nodes.conditioningPolymorphic'), - }, - ControlCollection: { - color: 'teal.500', - description: t('nodes.controlCollectionDescription'), - title: t('nodes.controlCollection'), - }, - ControlField: { - color: 'teal.500', - description: t('nodes.controlFieldDescription'), - title: t('nodes.controlField'), - }, - ControlNetModelField: { - color: 'teal.500', - description: 'TODO', - title: 'ControlNet', - }, - ControlPolymorphic: { - color: 'teal.500', - description: 'Control info passed between nodes.', - title: 'Control Polymorphic', - }, - DenoiseMaskField: { - color: 'blue.300', - description: t('nodes.denoiseMaskFieldDescription'), - title: t('nodes.denoiseMaskField'), - }, - enum: { - color: 'blue.500', - description: t('nodes.enumDescription'), - title: t('nodes.enum'), - }, - float: { - color: 'orange.500', - description: t('nodes.floatDescription'), - title: t('nodes.float'), - }, - FloatCollection: { - color: 'orange.500', - description: t('nodes.floatCollectionDescription'), - title: t('nodes.floatCollection'), - }, - FloatPolymorphic: { - color: 'orange.500', - description: t('nodes.floatPolymorphicDescription'), - title: t('nodes.floatPolymorphic'), - }, - ImageCollection: { - color: 'purple.500', - description: t('nodes.imageCollectionDescription'), - title: t('nodes.imageCollection'), - }, - ImageField: { - color: 'purple.500', - description: t('nodes.imageFieldDescription'), - title: t('nodes.imageField'), - }, - BoardField: { - color: 'purple.500', - description: t('nodes.imageFieldDescription'), - title: t('nodes.imageField'), - }, - ImagePolymorphic: { - color: 'purple.500', - description: t('nodes.imagePolymorphicDescription'), - title: t('nodes.imagePolymorphic'), - }, - integer: { - color: 'red.500', - description: t('nodes.integerDescription'), - title: t('nodes.integer'), - }, - IntegerCollection: { - color: 'red.500', - description: t('nodes.integerCollectionDescription'), - title: t('nodes.integerCollection'), - }, - IntegerPolymorphic: { - color: 'red.500', - description: t('nodes.integerPolymorphicDescription'), - title: t('nodes.integerPolymorphic'), - }, - IPAdapterCollection: { - color: 'teal.500', - description: t('nodes.ipAdapterCollectionDescription'), - title: t('nodes.ipAdapterCollection'), - }, - IPAdapterField: { - color: 'teal.500', - description: t('nodes.ipAdapterDescription'), - title: t('nodes.ipAdapter'), - }, - IPAdapterModelField: { - color: 'teal.500', - description: t('nodes.ipAdapterModelDescription'), - title: t('nodes.ipAdapterModel'), - }, - IPAdapterPolymorphic: { - color: 'teal.500', - description: t('nodes.ipAdapterPolymorphicDescription'), - title: t('nodes.ipAdapterPolymorphic'), - }, - LatentsCollection: { - color: 'pink.500', - description: t('nodes.latentsCollectionDescription'), - title: t('nodes.latentsCollection'), - }, - LatentsField: { - color: 'pink.500', - description: t('nodes.latentsFieldDescription'), - title: t('nodes.latentsField'), - }, - LatentsPolymorphic: { - color: 'pink.500', - description: t('nodes.latentsPolymorphicDescription'), - title: t('nodes.latentsPolymorphic'), - }, - LoRAModelField: { - color: 'teal.500', - description: t('nodes.loRAModelFieldDescription'), - title: t('nodes.loRAModelField'), - }, - MainModelField: { - color: 'teal.500', - description: t('nodes.mainModelFieldDescription'), - title: t('nodes.mainModelField'), - }, - ONNXModelField: { - color: 'teal.500', - description: t('nodes.oNNXModelFieldDescription'), - title: t('nodes.oNNXModelField'), - }, - Scheduler: { - color: 'base.500', - description: t('nodes.schedulerDescription'), - title: t('nodes.scheduler'), - }, - SDXLMainModelField: { - color: 'teal.500', - description: t('nodes.sDXLMainModelFieldDescription'), - title: t('nodes.sDXLMainModelField'), - }, - SDXLRefinerModelField: { - color: 'teal.500', - description: t('nodes.sDXLRefinerModelFieldDescription'), - title: t('nodes.sDXLRefinerModelField'), - }, - string: { - color: 'yellow.500', - description: t('nodes.stringDescription'), - title: t('nodes.string'), - }, - StringCollection: { - color: 'yellow.500', - description: t('nodes.stringCollectionDescription'), - title: t('nodes.stringCollection'), - }, - StringPolymorphic: { - color: 'yellow.500', - description: t('nodes.stringPolymorphicDescription'), - title: t('nodes.stringPolymorphic'), - }, - T2IAdapterCollection: { - color: 'teal.500', - description: t('nodes.t2iAdapterCollectionDescription'), - title: t('nodes.t2iAdapterCollection'), - }, - T2IAdapterField: { - color: 'teal.500', - description: t('nodes.t2iAdapterFieldDescription'), - title: t('nodes.t2iAdapterField'), - }, - T2IAdapterModelField: { - color: 'teal.500', - description: 'TODO', - title: 'T2I-Adapter', - }, - T2IAdapterPolymorphic: { - color: 'teal.500', - description: 'T2I-Adapter info passed between nodes.', - title: 'T2I-Adapter Polymorphic', - }, - UNetField: { - color: 'red.500', - description: t('nodes.uNetFieldDescription'), - title: t('nodes.uNetField'), - }, - VaeField: { - color: 'blue.500', - description: t('nodes.vaeFieldDescription'), - title: t('nodes.vaeField'), - }, - VaeModelField: { - color: 'teal.500', - description: t('nodes.vaeModelFieldDescription'), - title: t('nodes.vaeModelField'), - }, +/** + * Colors for each field type - applies to their handles and edges. + */ +export const FIELD_COLORS: { [key: string]: string } = { + BoardField: 'purple.500', + BooleanField: 'green.500', + ClipField: 'green.500', + ColorField: 'pink.300', + ConditioningField: 'cyan.500', + ControlField: 'teal.500', + ControlNetModelField: 'teal.500', + EnumField: 'blue.500', + FloatField: 'orange.500', + ImageField: 'purple.500', + IntegerField: 'red.500', + IPAdapterField: 'teal.500', + IPAdapterModelField: 'teal.500', + LatentsField: 'pink.500', + LoRAModelField: 'teal.500', + MainModelField: 'teal.500', + ONNXModelField: 'teal.500', + SDXLMainModelField: 'teal.500', + SDXLRefinerModelField: 'teal.500', + StringField: 'yellow.500', + T2IAdapterField: 'teal.500', + T2IAdapterModelField: 'teal.500', + UNetField: 'red.500', + VaeField: 'blue.500', + VaeModelField: 'teal.500', }; diff --git a/invokeai/frontend/web/src/features/nodes/types/error.ts b/invokeai/frontend/web/src/features/nodes/types/error.ts new file mode 100644 index 0000000000..b7ffb753bc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/error.ts @@ -0,0 +1,59 @@ +/** + * Invalid Workflow Version Error + * Raised when a workflow version is not recognized. + */ +export class WorkflowVersionError extends Error { + /** + * Create WorkflowVersionError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * Unable to Update Node Error + * Raised when a node cannot be updated. + */ +export class NodeUpdateError extends Error { + /** + * Create NodeUpdateError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * FieldTypeParseError + * Raised when a field cannot be parsed from a field schema. + */ +export class FieldTypeParseError extends Error { + /** + * Create FieldTypeParseError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} + +/** + * UnsupportedFieldTypeError + * Raised when an unsupported field type is parsed. + */ +export class UnsupportedFieldTypeError extends Error { + /** + * Create UnsupportedFieldTypeError + * @param {String} message + */ + constructor(message: string) { + super(message); + this.name = this.constructor.name; + } +} diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts new file mode 100644 index 0000000000..dd1c50f6e3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -0,0 +1,1114 @@ +import { z } from 'zod'; +import { + zBoardField, + zColorField, + zControlNetModelField, + zIPAdapterModelField, + zImageField, + zLoRAModelField, + zMainOrONNXModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from './common'; + +/** + * zod schemas & inferred types for input field values. + * + * These schemas and types are only required for field types that have UI components and allow the + * user to directly provide values. + * + * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. + * + * If a field type does not have a UI component, then it does not need to be included here, because + * we never store its value. Such field types will be handled via the "StatelessField" logic. + * + * Fields require: + * - zFieldType - zod schema for the field type + * - zFieldValue - zod schema for the field value + * - zFieldInputInstance - zod schema for the field's input instance + * - zFieldOutputInstance - zod schema for the field's output instance + * - zFieldInputTemplate - zod schema for the field's input template + * - zFieldOutputTemplate - zod schema for the field's output template + * + * These then must be added to the unions at the bottom of this file. + */ + +/** */ + +// #region Base schemas & misc +export const zFieldInput = z.enum(['connection', 'direct', 'any']); +export type FieldInput = z.infer; + +export const zFieldUIComponent = z.enum(['none', 'textarea', 'slider']); +export type FieldUIComponent = z.infer; + +export const zFieldInstanceBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), +}); +export const zFieldInputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('input'), + label: z.string().nullish(), +}); +export const zFieldOutputInstanceBase = zFieldInstanceBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldInstanceBase = z.infer; +export type FieldInputInstanceBase = z.infer; +export type FieldOutputInstanceBase = z.infer; + +export const zFieldTemplateBase = z.object({ + name: z.string().min(1), + title: z.string().min(1), + description: z.string().nullish(), + ui_hidden: z.boolean(), + ui_type: z.string().nullish(), + ui_order: z.number().int().nullish(), +}); +export const zFieldInputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('input'), + input: zFieldInput, + required: z.boolean(), + ui_component: zFieldUIComponent.nullish(), + ui_choice_labels: z.record(z.string()).nullish(), +}); +export const zFieldOutputTemplateBase = zFieldTemplateBase.extend({ + fieldKind: z.literal('output'), +}); +export type FieldTemplateBase = z.infer; +export type FieldInputTemplateBase = z.infer; +export type FieldOutputTemplateBase = z.infer; + +export const zFieldTypeBase = z.object({ + isCollection: z.boolean(), + isPolymorphic: z.boolean(), +}); + +export const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); +export type FieldIdentifier = z.infer; +export const isFieldIdentifier = (val: unknown): val is FieldIdentifier => + zFieldIdentifier.safeParse(val).success; +// #endregion + +// #region IntegerField +export const zIntegerFieldType = zFieldTypeBase.extend({ + name: z.literal('IntegerField'), +}); +export const zIntegerFieldValue = z.number().int(); +export const zIntegerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zIntegerFieldType, + value: zIntegerFieldValue, +}); +export const zIntegerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zIntegerFieldType, +}); +export const zIntegerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zIntegerFieldType, + default: zIntegerFieldValue, + multipleOf: z.number().int().optional(), + maximum: z.number().int().optional(), + exclusiveMaximum: z.number().int().optional(), + minimum: z.number().int().optional(), + exclusiveMinimum: z.number().int().optional(), +}); +export const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zIntegerFieldType, +}); +export type IntegerFieldType = z.infer; +export type IntegerFieldValue = z.infer; +export type IntegerFieldInputInstance = z.infer< + typeof zIntegerFieldInputInstance +>; +export type IntegerFieldInputTemplate = z.infer< + typeof zIntegerFieldInputTemplate +>; +export const isIntegerFieldInputInstance = ( + val: unknown +): val is IntegerFieldInputInstance => + zIntegerFieldInputInstance.safeParse(val).success; +export const isIntegerFieldInputTemplate = ( + val: unknown +): val is IntegerFieldInputTemplate => + zIntegerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region FloatField +export const zFloatFieldType = zFieldTypeBase.extend({ + name: z.literal('FloatField'), +}); +export const zFloatFieldValue = z.number(); +export const zFloatFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zFloatFieldType, + value: zFloatFieldValue, +}); +export const zFloatFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zFloatFieldType, +}); +export const zFloatFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zFloatFieldType, + default: zFloatFieldValue, + multipleOf: z.number().optional(), + maximum: z.number().optional(), + exclusiveMaximum: z.number().optional(), + minimum: z.number().optional(), + exclusiveMinimum: z.number().optional(), +}); +export const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zFloatFieldType, +}); +export type FloatFieldType = z.infer; +export type FloatFieldValue = z.infer; +export type FloatFieldInputInstance = z.infer; +export type FloatFieldOutputInstance = z.infer< + typeof zFloatFieldOutputInstance +>; +export type FloatFieldInputTemplate = z.infer; +export type FloatFieldOutputTemplate = z.infer< + typeof zFloatFieldOutputTemplate +>; +export const isFloatFieldInputInstance = ( + val: unknown +): val is FloatFieldInputInstance => + zFloatFieldInputInstance.safeParse(val).success; +export const isFloatFieldInputTemplate = ( + val: unknown +): val is FloatFieldInputTemplate => + zFloatFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StringField +export const zStringFieldType = zFieldTypeBase.extend({ + name: z.literal('StringField'), +}); +export const zStringFieldValue = z.string(); +export const zStringFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStringFieldType, + value: zStringFieldValue, +}); +export const zStringFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStringFieldType, +}); +export const zStringFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStringFieldType, + default: zStringFieldValue, + maxLength: z.number().int().optional(), + minLength: z.number().int().optional(), +}); +export const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStringFieldType, +}); + +export type StringFieldType = z.infer; +export type StringFieldValue = z.infer; +export type StringFieldInputInstance = z.infer< + typeof zStringFieldInputInstance +>; +export type StringFieldOutputInstance = z.infer< + typeof zStringFieldOutputInstance +>; +export type StringFieldInputTemplate = z.infer< + typeof zStringFieldInputTemplate +>; +export type StringFieldOutputTemplate = z.infer< + typeof zStringFieldOutputTemplate +>; +export const isStringFieldInputInstance = ( + val: unknown +): val is StringFieldInputInstance => + zStringFieldInputInstance.safeParse(val).success; +export const isStringFieldInputTemplate = ( + val: unknown +): val is StringFieldInputTemplate => + zStringFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BooleanField +export const zBooleanFieldType = zFieldTypeBase.extend({ + name: z.literal('BooleanField'), +}); +export const zBooleanFieldValue = z.boolean(); +export const zBooleanFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBooleanFieldType, + value: zBooleanFieldValue, +}); +export const zBooleanFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBooleanFieldType, +}); +export const zBooleanFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBooleanFieldType, + default: zBooleanFieldValue, +}); +export const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBooleanFieldType, +}); +export type BooleanFieldType = z.infer; +export type BooleanFieldValue = z.infer; +export type BooleanFieldInputInstance = z.infer< + typeof zBooleanFieldInputInstance +>; +export type BooleanFieldOutputInstance = z.infer< + typeof zBooleanFieldOutputInstance +>; +export type BooleanFieldInputTemplate = z.infer< + typeof zBooleanFieldInputTemplate +>; +export type BooleanFieldOutputTemplate = z.infer< + typeof zBooleanFieldOutputTemplate +>; +export const isBooleanFieldInputInstance = ( + val: unknown +): val is BooleanFieldInputInstance => + zBooleanFieldInputInstance.safeParse(val).success; +export const isBooleanFieldInputTemplate = ( + val: unknown +): val is BooleanFieldInputTemplate => + zBooleanFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region EnumField +export const zEnumFieldType = zFieldTypeBase.extend({ + name: z.literal('EnumField'), +}); +export const zEnumFieldValue = z.string(); +export const zEnumFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zEnumFieldType, + value: zEnumFieldValue, +}); +export const zEnumFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zEnumFieldType, +}); +export const zEnumFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zEnumFieldType, + default: zEnumFieldValue, + options: z.array(z.string()), + labels: z.record(z.string()).optional(), +}); +export const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zEnumFieldType, +}); +export type EnumFieldType = z.infer; +export type EnumFieldValue = z.infer; +export type EnumFieldInputInstance = z.infer; +export type EnumFieldOutputInstance = z.infer; +export type EnumFieldInputTemplate = z.infer; +export type EnumFieldOutputTemplate = z.infer; +export const isEnumFieldInputInstance = ( + val: unknown +): val is EnumFieldInputInstance => + zEnumFieldInputInstance.safeParse(val).success; +export const isEnumFieldInputTemplate = ( + val: unknown +): val is EnumFieldInputTemplate => + zEnumFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ImageField +export const zImageFieldType = zFieldTypeBase.extend({ + name: z.literal('ImageField'), +}); +export const zImageFieldValue = zImageField.optional(); +export const zImageFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zImageFieldType, + value: zImageFieldValue, +}); +export const zImageFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zImageFieldType, +}); +export const zImageFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zImageFieldType, + default: zImageFieldValue, +}); +export const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zImageFieldType, +}); +export type ImageFieldType = z.infer; +export type ImageFieldValue = z.infer; +export type ImageFieldInputInstance = z.infer; +export type ImageFieldOutputInstance = z.infer< + typeof zImageFieldOutputInstance +>; +export type ImageFieldInputTemplate = z.infer; +export type ImageFieldOutputTemplate = z.infer< + typeof zImageFieldOutputTemplate +>; +export const isImageFieldInputInstance = ( + val: unknown +): val is ImageFieldInputInstance => + zImageFieldInputInstance.safeParse(val).success; +export const isImageFieldInputTemplate = ( + val: unknown +): val is ImageFieldInputTemplate => + zImageFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region BoardField +export const zBoardFieldType = zFieldTypeBase.extend({ + name: z.literal('BoardField'), +}); +export const zBoardFieldValue = zBoardField.optional(); +export const zBoardFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zBoardFieldType, + value: zBoardFieldValue, +}); +export const zBoardFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zBoardFieldType, +}); +export const zBoardFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zBoardFieldType, + default: zBoardFieldValue, +}); +export const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zBoardFieldType, +}); +export type BoardFieldType = z.infer; +export type BoardFieldValue = z.infer; +export type BoardFieldInputInstance = z.infer; +export type BoardFieldOutputInstance = z.infer< + typeof zBoardFieldOutputInstance +>; +export type BoardFieldInputTemplate = z.infer; +export type BoardFieldOutputTemplate = z.infer< + typeof zBoardFieldOutputTemplate +>; +export const isBoardFieldInputInstance = ( + val: unknown +): val is BoardFieldInputInstance => + zBoardFieldInputInstance.safeParse(val).success; +export const isBoardFieldInputTemplate = ( + val: unknown +): val is BoardFieldInputTemplate => + zBoardFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ColorField +export const zColorFieldType = zFieldTypeBase.extend({ + name: z.literal('ColorField'), +}); +export const zColorFieldValue = zColorField.optional(); +export const zColorFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zColorFieldType, + value: zColorFieldValue, +}); +export const zColorFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zColorFieldType, +}); +export const zColorFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zColorFieldType, + default: zColorFieldValue, +}); +export const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zColorFieldType, +}); +export type ColorFieldType = z.infer; +export type ColorFieldValue = z.infer; +export type ColorFieldInputInstance = z.infer; +export type ColorFieldOutputInstance = z.infer< + typeof zColorFieldOutputInstance +>; +export type ColorFieldInputTemplate = z.infer; +export type ColorFieldOutputTemplate = z.infer< + typeof zColorFieldOutputTemplate +>; +export const isColorFieldInputInstance = ( + val: unknown +): val is ColorFieldInputInstance => + zColorFieldInputInstance.safeParse(val).success; +export const isColorFieldInputTemplate = ( + val: unknown +): val is ColorFieldInputTemplate => + zColorFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region MainModelField +export const zMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('MainModelField'), +}); +export const zMainModelFieldValue = zMainOrONNXModelField.optional(); +export const zMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zMainModelFieldType, + value: zMainModelFieldValue, +}); +export const zMainModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zMainModelFieldType, +}); +export const zMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zMainModelFieldType, + default: zMainModelFieldValue, +}); +export const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zMainModelFieldType, +}); +export type MainModelFieldType = z.infer; +export type MainModelFieldValue = z.infer; +export type MainModelFieldInputInstance = z.infer< + typeof zMainModelFieldInputInstance +>; +export type MainModelFieldOutputInstance = z.infer< + typeof zMainModelFieldOutputInstance +>; +export type MainModelFieldInputTemplate = z.infer< + typeof zMainModelFieldInputTemplate +>; +export type MainModelFieldOutputTemplate = z.infer< + typeof zMainModelFieldOutputTemplate +>; +export const isMainModelFieldInputInstance = ( + val: unknown +): val is MainModelFieldInputInstance => + zMainModelFieldInputInstance.safeParse(val).success; +export const isMainModelFieldInputTemplate = ( + val: unknown +): val is MainModelFieldInputTemplate => + zMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLMainModelField +export const zSDXLMainModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLMainModelField'), +}); +export const zSDXLMainModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL models only. +export const zSDXLMainModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + value: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zSDXLMainModelFieldType, + }); +export const zSDXLMainModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + default: zSDXLMainModelFieldValue, +}); +export const zSDXLMainModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zSDXLMainModelFieldType, + }); +export type SDXLMainModelFieldType = z.infer; +export type SDXLMainModelFieldValue = z.infer; +export type SDXLMainModelFieldInputInstance = z.infer< + typeof zSDXLMainModelFieldInputInstance +>; +export type SDXLMainModelFieldOutputInstance = z.infer< + typeof zSDXLMainModelFieldOutputInstance +>; +export type SDXLMainModelFieldInputTemplate = z.infer< + typeof zSDXLMainModelFieldInputTemplate +>; +export type SDXLMainModelFieldOutputTemplate = z.infer< + typeof zSDXLMainModelFieldOutputTemplate +>; +export const isSDXLMainModelFieldInputInstance = ( + val: unknown +): val is SDXLMainModelFieldInputInstance => + zSDXLMainModelFieldInputInstance.safeParse(val).success; +export const isSDXLMainModelFieldInputTemplate = ( + val: unknown +): val is SDXLMainModelFieldInputTemplate => + zSDXLMainModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SDXLRefinerModelField +export const zSDXLRefinerModelFieldType = zFieldTypeBase.extend({ + name: z.literal('SDXLRefinerModelField'), +}); +export const zSDXLRefinerModelFieldValue = zMainModelFieldValue; // TODO: Narrow to SDXL Refiner models only. +export const zSDXLRefinerModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + value: zSDXLRefinerModelFieldValue, + }); +export const zSDXLRefinerModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zSDXLRefinerModelFieldType, + }); +export const zSDXLRefinerModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + default: zSDXLRefinerModelFieldValue, + }); +export const zSDXLRefinerModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zSDXLRefinerModelFieldType, + }); +export type SDXLRefinerModelFieldType = z.infer< + typeof zSDXLRefinerModelFieldType +>; +export type SDXLRefinerModelFieldValue = z.infer< + typeof zSDXLRefinerModelFieldValue +>; +export type SDXLRefinerModelFieldInputInstance = z.infer< + typeof zSDXLRefinerModelFieldInputInstance +>; +export type SDXLRefinerModelFieldOutputInstance = z.infer< + typeof zSDXLRefinerModelFieldOutputInstance +>; +export type SDXLRefinerModelFieldInputTemplate = z.infer< + typeof zSDXLRefinerModelFieldInputTemplate +>; +export type SDXLRefinerModelFieldOutputTemplate = z.infer< + typeof zSDXLRefinerModelFieldOutputTemplate +>; +export const isSDXLRefinerModelFieldInputInstance = ( + val: unknown +): val is SDXLRefinerModelFieldInputInstance => + zSDXLRefinerModelFieldInputInstance.safeParse(val).success; +export const isSDXLRefinerModelFieldInputTemplate = ( + val: unknown +): val is SDXLRefinerModelFieldInputTemplate => + zSDXLRefinerModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region VAEModelField +export const zVAEModelFieldType = zFieldTypeBase.extend({ + name: z.literal('VAEModelField'), +}); +export const zVAEModelFieldValue = zVAEModelField.optional(); +export const zVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zVAEModelFieldType, + value: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zVAEModelFieldType, +}); +export const zVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zVAEModelFieldType, + default: zVAEModelFieldValue, +}); +export const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zVAEModelFieldType, +}); +export type VAEModelFieldType = z.infer; +export type VAEModelFieldValue = z.infer; +export type VAEModelFieldInputInstance = z.infer< + typeof zVAEModelFieldInputInstance +>; +export type VAEModelFieldOutputInstance = z.infer< + typeof zVAEModelFieldOutputInstance +>; +export type VAEModelFieldInputTemplate = z.infer< + typeof zVAEModelFieldInputTemplate +>; +export type VAEModelFieldOutputTemplate = z.infer< + typeof zVAEModelFieldOutputTemplate +>; +export const isVAEModelFieldInputInstance = ( + val: unknown +): val is VAEModelFieldInputInstance => + zVAEModelFieldInputInstance.safeParse(val).success; +export const isVAEModelFieldInputTemplate = ( + val: unknown +): val is VAEModelFieldInputTemplate => + zVAEModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region LoRAModelField +export const zLoRAModelFieldType = zFieldTypeBase.extend({ + name: z.literal('LoRAModelField'), +}); +export const zLoRAModelFieldValue = zLoRAModelField.optional(); +export const zLoRAModelFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zLoRAModelFieldType, + value: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zLoRAModelFieldType, +}); +export const zLoRAModelFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zLoRAModelFieldType, + default: zLoRAModelFieldValue, +}); +export const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zLoRAModelFieldType, +}); +export type LoRAModelFieldType = z.infer; +export type LoRAModelFieldValue = z.infer; +export type LoRAModelFieldInputInstance = z.infer< + typeof zLoRAModelFieldInputInstance +>; +export type LoRAModelFieldOutputInstance = z.infer< + typeof zLoRAModelFieldOutputInstance +>; +export type LoRAModelFieldInputTemplate = z.infer< + typeof zLoRAModelFieldInputTemplate +>; +export type LoRAModelFieldOutputTemplate = z.infer< + typeof zLoRAModelFieldOutputTemplate +>; +export const isLoRAModelFieldInputInstance = ( + val: unknown +): val is LoRAModelFieldInputInstance => + zLoRAModelFieldInputInstance.safeParse(val).success; +export const isLoRAModelFieldInputTemplate = ( + val: unknown +): val is LoRAModelFieldInputTemplate => + zLoRAModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region ControlNetModelField +export const zControlNetModelFieldType = zFieldTypeBase.extend({ + name: z.literal('ControlNetModelField'), +}); +export const zControlNetModelFieldValue = zControlNetModelField.optional(); +export const zControlNetModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zControlNetModelFieldType, + value: zControlNetModelFieldValue, + }); +export const zControlNetModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zControlNetModelFieldType, + }); +export const zControlNetModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zControlNetModelFieldType, + default: zControlNetModelFieldValue, + }); +export const zControlNetModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zControlNetModelFieldType, + }); +export type ControlNetModelFieldType = z.infer< + typeof zControlNetModelFieldType +>; +export type ControlNetModelFieldValue = z.infer< + typeof zControlNetModelFieldValue +>; +export type ControlNetModelFieldInputInstance = z.infer< + typeof zControlNetModelFieldInputInstance +>; +export type ControlNetModelFieldOutputInstance = z.infer< + typeof zControlNetModelFieldOutputInstance +>; +export type ControlNetModelFieldInputTemplate = z.infer< + typeof zControlNetModelFieldInputTemplate +>; +export type ControlNetModelFieldOutputTemplate = z.infer< + typeof zControlNetModelFieldOutputTemplate +>; +export const isControlNetModelFieldInputInstance = ( + val: unknown +): val is ControlNetModelFieldInputInstance => + zControlNetModelFieldInputInstance.safeParse(val).success; +export const isControlNetModelFieldInputTemplate = ( + val: unknown +): val is ControlNetModelFieldInputTemplate => + zControlNetModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region IPAdapterModelField +export const zIPAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('IPAdapterModelField'), +}); +export const zIPAdapterModelFieldValue = zIPAdapterModelField.optional(); +export const zIPAdapterModelFieldInputInstance = zFieldInputInstanceBase.extend( + { + type: zIPAdapterModelFieldType, + value: zIPAdapterModelFieldValue, + } +); +export const zIPAdapterModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zIPAdapterModelFieldType, + }); +export const zIPAdapterModelFieldInputTemplate = zFieldInputTemplateBase.extend( + { type: zIPAdapterModelFieldType, default: zIPAdapterModelFieldValue } +); +export const zIPAdapterModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zIPAdapterModelFieldType, + }); +export type IPAdapterModelFieldType = z.infer; +export type IPAdapterModelFieldValue = z.infer< + typeof zIPAdapterModelFieldValue +>; +export type IPAdapterModelFieldInputInstance = z.infer< + typeof zIPAdapterModelFieldInputInstance +>; +export type IPAdapterModelFieldOutputInstance = z.infer< + typeof zIPAdapterModelFieldOutputInstance +>; +export type IPAdapterModelFieldInputTemplate = z.infer< + typeof zIPAdapterModelFieldInputTemplate +>; +export type IPAdapterModelFieldOutputTemplate = z.infer< + typeof zIPAdapterModelFieldOutputTemplate +>; +export const isIPAdapterModelFieldInputInstance = ( + val: unknown +): val is IPAdapterModelFieldInputInstance => + zIPAdapterModelFieldInputInstance.safeParse(val).success; +export const isIPAdapterModelFieldInputTemplate = ( + val: unknown +): val is IPAdapterModelFieldInputTemplate => + zIPAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region T2IAdapterField +export const zT2IAdapterModelFieldType = zFieldTypeBase.extend({ + name: z.literal('T2IAdapterModelField'), +}); +export const zT2IAdapterModelFieldValue = zT2IAdapterModelField.optional(); +export const zT2IAdapterModelFieldInputInstance = + zFieldInputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + value: zT2IAdapterModelFieldValue, + }); +export const zT2IAdapterModelFieldOutputInstance = + zFieldOutputInstanceBase.extend({ + type: zT2IAdapterModelFieldType, + }); +export const zT2IAdapterModelFieldInputTemplate = + zFieldInputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + default: zT2IAdapterModelFieldValue, + }); +export const zT2IAdapterModelFieldOutputTemplate = + zFieldOutputTemplateBase.extend({ + type: zT2IAdapterModelFieldType, + }); +export type T2IAdapterModelFieldType = z.infer< + typeof zT2IAdapterModelFieldType +>; +export type T2IAdapterModelFieldValue = z.infer< + typeof zT2IAdapterModelFieldValue +>; +export type T2IAdapterModelFieldInputInstance = z.infer< + typeof zT2IAdapterModelFieldInputInstance +>; +export type T2IAdapterModelFieldOutputInstance = z.infer< + typeof zT2IAdapterModelFieldOutputInstance +>; +export type T2IAdapterModelFieldInputTemplate = z.infer< + typeof zT2IAdapterModelFieldInputTemplate +>; +export type T2IAdapterModelFieldOutputTemplate = z.infer< + typeof zT2IAdapterModelFieldOutputTemplate +>; +export const isT2IAdapterModelFieldInputInstance = ( + val: unknown +): val is T2IAdapterModelFieldInputInstance => + zT2IAdapterModelFieldInputInstance.safeParse(val).success; +export const isT2IAdapterModelFieldInputTemplate = ( + val: unknown +): val is T2IAdapterModelFieldInputTemplate => + zT2IAdapterModelFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region SchedulerField +export const zSchedulerFieldType = zFieldTypeBase.extend({ + name: z.literal('SchedulerField'), +}); +export const zSchedulerFieldValue = zSchedulerField.optional(); +export const zSchedulerFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zSchedulerFieldType, + value: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zSchedulerFieldType, +}); +export const zSchedulerFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zSchedulerFieldType, + default: zSchedulerFieldValue, +}); +export const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zSchedulerFieldType, +}); +export type SchedulerFieldType = z.infer; +export type SchedulerFieldValue = z.infer; +export type SchedulerFieldInputInstance = z.infer< + typeof zSchedulerFieldInputInstance +>; +export type SchedulerFieldOutputInstance = z.infer< + typeof zSchedulerFieldOutputInstance +>; +export type SchedulerFieldInputTemplate = z.infer< + typeof zSchedulerFieldInputTemplate +>; +export type SchedulerFieldOutputTemplate = z.infer< + typeof zSchedulerFieldOutputTemplate +>; +export const isSchedulerFieldInputInstance = ( + val: unknown +): val is SchedulerFieldInputInstance => + zSchedulerFieldInputInstance.safeParse(val).success; +export const isSchedulerFieldInputTemplate = ( + val: unknown +): val is SchedulerFieldInputTemplate => + zSchedulerFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatelessField +/** + * StatelessField is a catchall for stateless fields with no UI input components. They do not + * do not support "direct" input, instead only accepting connections from other fields. + * + * This field type serves as a "generic" field type. + * + * Examples include: + * - Fields like UNetField or LatentsField where we do not allow direct UI input + * - Reserved fields like IsIntermediate + * - Any other field we don't have full-on schemas for + */ +export const zStatelessFieldType = zFieldTypeBase.extend({ + name: z.string().min(1), // stateless --> we accept the field's name as the type +}); +export const zStatelessFieldValue = z.undefined().catch(undefined); // stateless --> no value, but making this z.never() introduces a lot of extra TS fanagling +export const zStatelessFieldInputInstance = zFieldInputInstanceBase.extend({ + type: zStatelessFieldType, + value: zStatelessFieldValue, +}); +export const zStatelessFieldOutputInstance = zFieldOutputInstanceBase.extend({ + type: zStatelessFieldType, +}); +export const zStatelessFieldInputTemplate = zFieldInputTemplateBase.extend({ + type: zStatelessFieldType, + default: zStatelessFieldValue, + input: z.literal('connection'), // stateless --> only accepts connection inputs +}); +export const zStatelessFieldOutputTemplate = zFieldOutputTemplateBase.extend({ + type: zStatelessFieldType, +}); + +export type StatelessFieldType = z.infer; +export type StatelessFieldValue = z.infer; +export type StatelessFieldInputInstance = z.infer< + typeof zStatelessFieldInputInstance +>; +export type StatelessFieldOutputInstance = z.infer< + typeof zStatelessFieldOutputInstance +>; +export type StatelessFieldInputTemplate = z.infer< + typeof zStatelessFieldInputTemplate +>; +export type StatelessFieldOutputTemplate = z.infer< + typeof zStatelessFieldOutputTemplate +>; +// #endregion + +/** + * Here we define the main field unions: + * - FieldType + * - FieldValue + * - FieldInputInstance + * - FieldOutputInstance + * - FieldInputTemplate + * - FieldOutputTemplate + * + * All stateful fields are unioned together, and then that union is unioned with StatelessField. + * + * This allows us to interact with stateful fields without needing to worry about "generic" handling + * for all other StatelessFields. + */ + +// #region StatefulFieldType & FieldType +export const zStatefulFieldType = z.union([ + zIntegerFieldType, + zFloatFieldType, + zStringFieldType, + zBooleanFieldType, + zEnumFieldType, + zImageFieldType, + zBoardFieldType, + zMainModelFieldType, + zSDXLMainModelFieldType, + zSDXLRefinerModelFieldType, + zVAEModelFieldType, + zLoRAModelFieldType, + zControlNetModelFieldType, + zIPAdapterModelFieldType, + zT2IAdapterModelFieldType, + zColorFieldType, + zSchedulerFieldType, +]); +export type StatefulFieldType = z.infer; +export const isStatefulFieldType = (val: unknown): val is StatefulFieldType => + zStatefulFieldType.safeParse(val).success; + +export const zFieldType = z.union([zStatefulFieldType, zStatelessFieldType]); +export type FieldType = z.infer; +export const isFieldType = (val: unknown): val is FieldType => + zFieldType.safeParse(val).success; +// #endregion + +// #region StatefulFieldValue & FieldValue +export const zStatefulFieldValue = z.union([ + zIntegerFieldValue, + zFloatFieldValue, + zStringFieldValue, + zBooleanFieldValue, + zEnumFieldValue, + zImageFieldValue, + zBoardFieldValue, + zMainModelFieldValue, + zSDXLMainModelFieldValue, + zSDXLRefinerModelFieldValue, + zVAEModelFieldValue, + zLoRAModelFieldValue, + zControlNetModelFieldValue, + zIPAdapterModelFieldValue, + zT2IAdapterModelFieldValue, + zColorFieldValue, + zSchedulerFieldValue, +]); +export type StatefulFieldValue = z.infer; +export const isStatefulFieldValue = (val: unknown): val is StatefulFieldValue => + zStatefulFieldValue.safeParse(val).success; + +export const zFieldValue = z.union([zStatefulFieldValue, zStatelessFieldValue]); +export type FieldValue = z.infer; +export const isFieldValue = (val: unknown): val is FieldValue => + zFieldValue.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputInstance & FieldInputInstance +export const zStatefulFieldInputInstance = z.union([ + zIntegerFieldInputInstance, + zFloatFieldInputInstance, + zStringFieldInputInstance, + zBooleanFieldInputInstance, + zEnumFieldInputInstance, + zImageFieldInputInstance, + zBoardFieldInputInstance, + zMainModelFieldInputInstance, + zSDXLMainModelFieldInputInstance, + zSDXLRefinerModelFieldInputInstance, + zVAEModelFieldInputInstance, + zLoRAModelFieldInputInstance, + zControlNetModelFieldInputInstance, + zIPAdapterModelFieldInputInstance, + zT2IAdapterModelFieldInputInstance, + zColorFieldInputInstance, + zSchedulerFieldInputInstance, +]); +export type StatefulFieldInputInstance = z.infer< + typeof zStatefulFieldInputInstance +>; +export const isStatefulFieldInputInstance = ( + val: unknown +): val is StatefulFieldInputInstance => + zStatefulFieldInputInstance.safeParse(val).success; + +export const zFieldInputInstance = z.union([ + zStatefulFieldInputInstance, + zStatelessFieldInputInstance, +]); +export type FieldInputInstance = z.infer; +export const isFieldInputInstance = (val: unknown): val is FieldInputInstance => + zFieldInputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputInstance & FieldOutputInstance +export const zStatefulFieldOutputInstance = z.union([ + zIntegerFieldOutputInstance, + zFloatFieldOutputInstance, + zStringFieldOutputInstance, + zBooleanFieldOutputInstance, + zEnumFieldOutputInstance, + zImageFieldOutputInstance, + zBoardFieldOutputInstance, + zMainModelFieldOutputInstance, + zSDXLMainModelFieldOutputInstance, + zSDXLRefinerModelFieldOutputInstance, + zVAEModelFieldOutputInstance, + zLoRAModelFieldOutputInstance, + zControlNetModelFieldOutputInstance, + zIPAdapterModelFieldOutputInstance, + zT2IAdapterModelFieldOutputInstance, + zColorFieldOutputInstance, + zSchedulerFieldOutputInstance, +]); +export type StatefulFieldOutputInstance = z.infer< + typeof zStatefulFieldOutputInstance +>; +export const isStatefulFieldOutputInstance = ( + val: unknown +): val is StatefulFieldOutputInstance => + zStatefulFieldOutputInstance.safeParse(val).success; + +export const zFieldOutputInstance = z.union([ + zStatefulFieldOutputInstance, + zStatelessFieldOutputInstance, +]); +export type FieldOutputInstance = z.infer; +export const isFieldOutputInstance = ( + val: unknown +): val is FieldOutputInstance => zFieldOutputInstance.safeParse(val).success; +// #endregion + +// #region StatefulFieldInputTemplate & FieldInputTemplate +export const zStatefulFieldInputTemplate = z.union([ + zIntegerFieldInputTemplate, + zFloatFieldInputTemplate, + zStringFieldInputTemplate, + zBooleanFieldInputTemplate, + zEnumFieldInputTemplate, + zImageFieldInputTemplate, + zBoardFieldInputTemplate, + zMainModelFieldInputTemplate, + zSDXLMainModelFieldInputTemplate, + zSDXLRefinerModelFieldInputTemplate, + zVAEModelFieldInputTemplate, + zLoRAModelFieldInputTemplate, + zControlNetModelFieldInputTemplate, + zIPAdapterModelFieldInputTemplate, + zT2IAdapterModelFieldInputTemplate, + zColorFieldInputTemplate, + zSchedulerFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type StatefulFieldInputTemplate = z.infer; +export const isStatefulFieldInputTemplate = ( + val: unknown +): val is StatefulFieldInputTemplate => + zStatefulFieldInputTemplate.safeParse(val).success; + +export const zFieldInputTemplate = z.union([ + zStatefulFieldInputTemplate, + zStatelessFieldInputTemplate, +]); +export type FieldInputTemplate = z.infer; +export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate => + zFieldInputTemplate.safeParse(val).success; +// #endregion + +// #region StatefulFieldOutputTemplate & FieldOutputTemplate +export const zStatefulFieldOutputTemplate = z.union([ + zIntegerFieldOutputTemplate, + zFloatFieldOutputTemplate, + zStringFieldOutputTemplate, + zBooleanFieldOutputTemplate, + zEnumFieldOutputTemplate, + zImageFieldOutputTemplate, + zBoardFieldOutputTemplate, + zMainModelFieldOutputTemplate, + zSDXLMainModelFieldOutputTemplate, + zSDXLRefinerModelFieldOutputTemplate, + zVAEModelFieldOutputTemplate, + zLoRAModelFieldOutputTemplate, + zControlNetModelFieldOutputTemplate, + zIPAdapterModelFieldOutputTemplate, + zT2IAdapterModelFieldOutputTemplate, + zColorFieldOutputTemplate, + zSchedulerFieldOutputTemplate, +]); +export type StatefulFieldOutputTemplate = z.infer< + typeof zStatefulFieldOutputTemplate +>; +export const isStatefulFieldOutputTemplate = ( + val: unknown +): val is StatefulFieldOutputTemplate => + zStatefulFieldOutputTemplate.safeParse(val).success; + +export const zFieldOutputTemplate = z.union([ + zStatefulFieldOutputTemplate, + zStatelessFieldOutputTemplate, +]); +export type FieldOutputTemplate = z.infer; +export const isFieldOutputTemplate = ( + val: unknown +): val is FieldOutputTemplate => zFieldOutputTemplate.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts new file mode 100644 index 0000000000..216db437b9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -0,0 +1,108 @@ +import { Node } from 'reactflow'; +import { z } from 'zod'; +import { zProgressImage } from './common'; +import { + zFieldInputInstance, + zFieldInputTemplate, + zFieldOutputInstance, + zFieldOutputTemplate, +} from './field'; +import { zSemVer } from './semver'; + +// #region InvocationTemplate +export const zInvocationTemplate = z.object({ + type: z.string(), + title: z.string(), + description: z.string(), + tags: z.array(z.string().min(1)), + inputs: z.record(zFieldInputTemplate), + outputs: z.record(zFieldOutputTemplate), + outputType: z.string().min(1), + withWorkflow: z.boolean(), + version: zSemVer, + useCache: z.boolean(), +}); +export type InvocationTemplate = z.infer; +// #endregion + +// #region NodeData +export const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + type: z.string().trim().min(1), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + embedWorkflow: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean(), + version: zSemVer, + inputs: z.record(zFieldInputInstance), + outputs: z.record(zFieldOutputInstance), +}); + +export const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); +export const zCurrentImageNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('current_image'), + label: z.string(), + isOpen: z.boolean(), +}); +export const zAnyNodeData = z.union([ + zInvocationNodeData, + zNotesNodeData, + zCurrentImageNodeData, +]); + +export type NotesNodeData = z.infer; +export type InvocationNodeData = z.infer; +export type CurrentImageNodeData = z.infer; +export type AnyNodeData = z.infer; + +export const isInvocationNode = ( + node?: Node +): node is Node => + Boolean(node && node.type === 'invocation'); +export const isNotesNode = ( + node?: Node +): node is Node => Boolean(node && node.type === 'notes'); +export const isProgressImageNode = ( + node?: Node +): node is Node => + Boolean(node && node.type === 'current_image'); +export const isInvocationNodeData = ( + node?: AnyNodeData +): node is InvocationNodeData => + Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type +// #endregion + +// #region NodeExecutionState +export const zNodeStatus = z.enum([ + 'PENDING', + 'IN_PROGRESS', + 'COMPLETED', + 'FAILED', +]); +export const zNodeExecutionState = z.object({ + nodeId: z.string().trim().min(1), + status: zNodeStatus, + progress: z.number().nullable(), + progressImage: zProgressImage.nullable(), + error: z.string().nullable(), + outputs: z.array(z.any()), +}); +export type NodeExecutionState = z.infer; +export type NodeStatus = z.infer; +// #endregion + +// #region Edges +export const zInvocationEdgeExtra = z.object({ + type: z.union([z.literal('default'), z.literal('collapsed')]), +}); +export type InvocationEdgeExtra = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/metadata.ts b/invokeai/frontend/web/src/features/nodes/types/metadata.ts new file mode 100644 index 0000000000..a22b8aed0e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/metadata.ts @@ -0,0 +1,81 @@ +import { z } from 'zod'; +import { + zControlField, + zIPAdapterField, + zLoRAModelField, + zMainModelField, + zONNXModelField, + zSDXLRefinerModelField, + zT2IAdapterField, + zVAEModelField, +} from './common'; + +// #region Metadata-optimized versions of schemas +// TODO: It's possible that `deepPartial` will be deprecated: +// - https://github.com/colinhacks/zod/issues/2106 +// - https://github.com/colinhacks/zod/issues/2854 +export const zLoRAMetadataItem = z.object({ + lora: zLoRAModelField.deepPartial(), + weight: z.number(), +}); +const zControlNetMetadataItem = zControlField.deepPartial(); +const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); +const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); +const zSDXLRefinerModelMetadataItem = zSDXLRefinerModelField.deepPartial(); +const zModelMetadataitem = z.union([ + zMainModelField.deepPartial(), + zONNXModelField.deepPartial(), +]); +const zVAEModelMetadataItem = zVAEModelField.deepPartial(); +export type LoRAMetadataItem = z.infer; +export type ControlNetMetadataItem = z.infer; +export type IPAdapterMetadataItem = z.infer; +export type T2IAdapterMetadataItem = z.infer; +export type SDXLRefinerModelMetadataItem = z.infer< + typeof zSDXLRefinerModelMetadataItem +>; +export type ModelMetadataitem = z.infer; +export type VAEModelMetadataItem = z.infer; +// #endregion + +// #region CoreMetadata +export const zCoreMetadata = z + .object({ + app_version: z.string().nullish().catch(null), + generation_mode: z.string().nullish().catch(null), + created_by: z.string().nullish().catch(null), + positive_prompt: z.string().nullish().catch(null), + negative_prompt: z.string().nullish().catch(null), + width: z.number().int().nullish().catch(null), + height: z.number().int().nullish().catch(null), + seed: z.number().int().nullish().catch(null), + rand_device: z.string().nullish().catch(null), + cfg_scale: z.number().nullish().catch(null), + steps: z.number().int().nullish().catch(null), + scheduler: z.string().nullish().catch(null), + clip_skip: z.number().int().nullish().catch(null), + model: zModelMetadataitem.nullish().catch(null), + controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), + ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), + t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), + loras: z.array(zLoRAMetadataItem).nullish().catch(null), + vae: zVAEModelMetadataItem.nullish().catch(null), + strength: z.number().nullish().catch(null), + hrf_enabled: z.boolean().nullish().catch(null), + hrf_strength: z.number().nullish().catch(null), + hrf_method: z.string().nullish().catch(null), + init_image: z.string().nullish().catch(null), + positive_style_prompt: z.string().nullish().catch(null), + negative_style_prompt: z.string().nullish().catch(null), + refiner_model: zSDXLRefinerModelMetadataItem.nullish().catch(null), + refiner_cfg_scale: z.number().nullish().catch(null), + refiner_steps: z.number().int().nullish().catch(null), + refiner_scheduler: z.string().nullish().catch(null), + refiner_positive_aesthetic_score: z.number().nullish().catch(null), + refiner_negative_aesthetic_score: z.number().nullish().catch(null), + refiner_start: z.number().nullish().catch(null), + }) + .passthrough(); +export type CoreMetadata = z.infer; + +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts b/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts new file mode 100644 index 0000000000..45c3852493 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts @@ -0,0 +1,69 @@ +import { forEach, isString } from 'lodash-es'; +import { z } from 'zod'; +import { WorkflowVersionError } from '../error'; +import { zSemVer } from '../semver'; +import { WorkflowV2, zWorkflowV2 } from '../workflow'; +import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from './v1/fieldTypeMap'; +import { WorkflowV1, zWorkflowV1 } from './v1/workflowV1'; +import { t } from 'i18next'; + +/** + * Helper schema to extract the version from a workflow. + * + * All properties except for the version are ignored in this schema. + */ +const zWorkflowMetaVersion = z.object({ + meta: z.object({ version: zSemVer }), +}); + +/** + * Migrates a workflow from V1 to V2. + */ +const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { + workflowToMigrate.nodes.forEach((node) => { + if (node.type === 'invocation') { + forEach(node.data.inputs, (input) => { + if (!isString(input.type)) { + return; + } + (input.type as unknown) = + FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[input.type]; + }); + forEach(node.data.outputs, (output) => { + if (!isString(output.type)) { + return; + } + (output.type as unknown) = + FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING[output.type]; + }); + } + }); + (workflowToMigrate.meta.version as WorkflowV2['meta']['version']) = '2.0.0'; + return zWorkflowV2.parse(workflowToMigrate); +}; + +/** + * Parses a workflow and migrates it to the latest version if necessary. + */ +export const parseAndMigrateWorkflow = (data: unknown): WorkflowV2 => { + const workflowVersionResult = zWorkflowMetaVersion.safeParse(data); + + if (!workflowVersionResult.success) { + throw new WorkflowVersionError(t('nodes.unableToGetWorkflowVersion')); + } + + const { version } = workflowVersionResult.data.meta; + + if (version === '1.0.0') { + const v1 = zWorkflowV1.parse(data); + return migrateV1toV2(v1); + } + + if (version === '2.0.0') { + return zWorkflowV2.parse(data); + } + + throw new WorkflowVersionError( + t('nodes.unrecognizedWorkflowVersion', { version }) + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts new file mode 100644 index 0000000000..facf015b02 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts @@ -0,0 +1,270 @@ +import { FieldType, StatefulFieldType } from '../../field'; +import { FieldTypeV1 } from './workflowV1'; + +/** + * Mapping of V1 field type strings to their *stateful* V2 field type counterparts. + */ +const FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2: { + [key in FieldTypeV1]?: StatefulFieldType; +} = { + BoardField: { name: 'BoardField', isCollection: false, isPolymorphic: false }, + boolean: { name: 'BooleanField', isCollection: false, isPolymorphic: false }, + BooleanCollection: { + name: 'BooleanField', + isCollection: true, + isPolymorphic: false, + }, + BooleanPolymorphic: { + name: 'BooleanField', + isCollection: false, + isPolymorphic: true, + }, + ColorField: { name: 'ColorField', isCollection: false, isPolymorphic: false }, + ColorCollection: { + name: 'ColorField', + isCollection: true, + isPolymorphic: false, + }, + ColorPolymorphic: { + name: 'ColorField', + isCollection: false, + isPolymorphic: true, + }, + ControlNetModelField: { + name: 'ControlNetModelField', + isCollection: false, + isPolymorphic: false, + }, + enum: { name: 'EnumField', isCollection: false, isPolymorphic: false }, + float: { name: 'FloatField', isCollection: false, isPolymorphic: false }, + FloatCollection: { + name: 'FloatField', + isCollection: true, + isPolymorphic: false, + }, + FloatPolymorphic: { + name: 'FloatField', + isCollection: false, + isPolymorphic: true, + }, + ImageCollection: { + name: 'ImageField', + isCollection: true, + isPolymorphic: false, + }, + ImageField: { name: 'ImageField', isCollection: false, isPolymorphic: false }, + ImagePolymorphic: { + name: 'ImageField', + isCollection: false, + isPolymorphic: true, + }, + integer: { name: 'IntegerField', isCollection: false, isPolymorphic: false }, + IntegerCollection: { + name: 'IntegerField', + isCollection: true, + isPolymorphic: false, + }, + IntegerPolymorphic: { + name: 'IntegerField', + isCollection: false, + isPolymorphic: true, + }, + IPAdapterModelField: { + name: 'IPAdapterModelField', + isCollection: false, + isPolymorphic: false, + }, + LoRAModelField: { + name: 'LoRAModelField', + isCollection: false, + isPolymorphic: false, + }, + MainModelField: { + name: 'MainModelField', + isCollection: false, + isPolymorphic: false, + }, + Scheduler: { + name: 'SchedulerField', + isCollection: false, + isPolymorphic: false, + }, + SDXLMainModelField: { + name: 'SDXLMainModelField', + isCollection: false, + isPolymorphic: false, + }, + SDXLRefinerModelField: { + name: 'SDXLRefinerModelField', + isCollection: false, + isPolymorphic: false, + }, + string: { name: 'StringField', isCollection: false, isPolymorphic: false }, + StringCollection: { + name: 'StringField', + isCollection: true, + isPolymorphic: false, + }, + StringPolymorphic: { + name: 'StringField', + isCollection: false, + isPolymorphic: true, + }, + T2IAdapterModelField: { + name: 'T2IAdapterModelField', + isCollection: false, + isPolymorphic: false, + }, + VaeModelField: { + name: 'VAEModelField', + isCollection: false, + isPolymorphic: false, + }, +}; + +/** + * Mapping of V1 field type strings to their *stateless* V2 field type counterparts. + * + * The type doesn't do what I want it to do. + * + * Ideally, the value of each propery would be a `FieldType` where `FieldType['name']` is not in + * `StatefulFieldType['name']`, but this is hard to represent. That's because `FieldType['name']` is + * actually widened to `string`, and TS's `Exclude` doesn't work on `string`. + * + * There's probably some way to do it with conditionals and intersections but I can't figure it out. + * + * Thus, this object was manually edited to ensure it is correct. + */ +const FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2: { + [key in FieldTypeV1]?: FieldType; +} = { + Any: { name: 'AnyField', isCollection: false, isPolymorphic: false }, + ClipField: { name: 'ClipField', isCollection: false, isPolymorphic: false }, + Collection: { + name: 'CollectionField', + isCollection: true, + isPolymorphic: false, + }, + CollectionItem: { + name: 'CollectionItemField', + isCollection: false, + isPolymorphic: false, + }, + ConditioningCollection: { + name: 'ConditioningField', + isCollection: true, + isPolymorphic: false, + }, + ConditioningField: { + name: 'ConditioningField', + isCollection: false, + isPolymorphic: false, + }, + ConditioningPolymorphic: { + name: 'ConditioningField', + isCollection: false, + isPolymorphic: true, + }, + ControlCollection: { + name: 'ControlField', + isCollection: true, + isPolymorphic: false, + }, + ControlField: { + name: 'ControlField', + isCollection: false, + isPolymorphic: false, + }, + ControlPolymorphic: { + name: 'ControlField', + isCollection: false, + isPolymorphic: true, + }, + DenoiseMaskField: { + name: 'DenoiseMaskField', + isCollection: false, + isPolymorphic: false, + }, + IPAdapterField: { + name: 'IPAdapterField', + isCollection: false, + isPolymorphic: false, + }, + IPAdapterCollection: { + name: 'IPAdapterField', + isCollection: true, + isPolymorphic: false, + }, + IPAdapterPolymorphic: { + name: 'IPAdapterField', + isCollection: false, + isPolymorphic: true, + }, + LatentsField: { + name: 'LatentsField', + isCollection: false, + isPolymorphic: false, + }, + LatentsCollection: { + name: 'LatentsField', + isCollection: true, + isPolymorphic: false, + }, + LatentsPolymorphic: { + name: 'LatentsField', + isCollection: false, + isPolymorphic: true, + }, + MetadataField: { + name: 'MetadataField', + isCollection: false, + isPolymorphic: false, + }, + MetadataCollection: { + name: 'MetadataField', + isCollection: true, + isPolymorphic: false, + }, + MetadataItemField: { + name: 'MetadataItemField', + isCollection: false, + isPolymorphic: false, + }, + MetadataItemCollection: { + name: 'MetadataItemField', + isCollection: true, + isPolymorphic: false, + }, + MetadataItemPolymorphic: { + name: 'MetadataItemField', + isCollection: false, + isPolymorphic: true, + }, + ONNXModelField: { + name: 'ONNXModelField', + isCollection: false, + isPolymorphic: false, + }, + T2IAdapterField: { + name: 'T2IAdapterField', + isCollection: false, + isPolymorphic: false, + }, + T2IAdapterCollection: { + name: 'T2IAdapterField', + isCollection: true, + isPolymorphic: false, + }, + T2IAdapterPolymorphic: { + name: 'T2IAdapterField', + isCollection: false, + isPolymorphic: true, + }, + UNetField: { name: 'UNetField', isCollection: false, isPolymorphic: false }, + VaeField: { name: 'VaeField', isCollection: false, isPolymorphic: false }, +}; + +export const FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING = { + ...FIELD_TYPE_V1_TO_STATEFUL_FIELD_TYPE_V2, + ...FIELD_TYPE_V1_TO_STATELESS_FIELD_TYPE_V2, +}; diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts new file mode 100644 index 0000000000..98e4158f9a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts @@ -0,0 +1,711 @@ +import { z } from 'zod'; + +// WorkflowV1 Schema + +const zScheduler = z.enum([ + 'euler', + 'deis', + 'ddim', + 'ddpm', + 'dpmpp_2s', + 'dpmpp_2m', + 'dpmpp_2m_sde', + 'dpmpp_sde', + 'heun', + 'kdpm_2', + 'lms', + 'pndm', + 'unipc', + 'euler_k', + 'dpmpp_2s_k', + 'dpmpp_2m_k', + 'dpmpp_2m_sde_k', + 'dpmpp_sde_k', + 'heun_k', + 'lms_k', + 'euler_a', + 'kdpm_2_a', + 'lcm', +]); +const zBaseModel = z.enum(['any', 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); +const zMainModel = z.object({ + model_name: z.string().min(1), + base_model: zBaseModel, + model_type: z.literal('main'), +}); +const zOnnxModel = z.object({ + model_name: z.string().min(1), + base_model: zBaseModel, + model_type: z.literal('onnx'), +}); + +const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]); + +// TODO: Get this from the OpenAPI schema? may be tricky... +const zFieldTypeV1 = z.enum([ + 'Any', + 'BoardField', + 'boolean', + 'BooleanCollection', + 'BooleanPolymorphic', + 'ClipField', + 'Collection', + 'CollectionItem', + 'ColorCollection', + 'ColorField', + 'ColorPolymorphic', + 'ConditioningCollection', + 'ConditioningField', + 'ConditioningPolymorphic', + 'ControlCollection', + 'ControlField', + 'ControlNetModelField', + 'ControlPolymorphic', + 'DenoiseMaskField', + 'enum', + 'float', + 'FloatCollection', + 'FloatPolymorphic', + 'ImageCollection', + 'ImageField', + 'ImagePolymorphic', + 'integer', + 'IntegerCollection', + 'IntegerPolymorphic', + 'IPAdapterCollection', + 'IPAdapterField', + 'IPAdapterModelField', + 'IPAdapterPolymorphic', + 'LatentsCollection', + 'LatentsField', + 'LatentsPolymorphic', + 'LoRAModelField', + 'MainModelField', + 'MetadataField', + 'MetadataCollection', + 'MetadataItemField', + 'MetadataItemCollection', + 'MetadataItemPolymorphic', + 'ONNXModelField', + 'Scheduler', + 'SDXLMainModelField', + 'SDXLRefinerModelField', + 'string', + 'StringCollection', + 'StringPolymorphic', + 'T2IAdapterCollection', + 'T2IAdapterField', + 'T2IAdapterModelField', + 'T2IAdapterPolymorphic', + 'UNetField', + 'VaeField', + 'VaeModelField', +]); +export type FieldTypeV1 = z.infer; + +const zFieldValueBase = z.object({ + id: z.string().trim().min(1), + name: z.string().trim().min(1), + type: zFieldTypeV1, +}); + +/** + * An output field is persisted across as part of the user's local state. + * + * An output field has two properties: + * - `id` a unique identifier + * - `name` the name of the field, which comes from the python dataclass + */ + +const zOutputFieldValue = zFieldValueBase.extend({ + fieldKind: z.literal('output'), +}); + +const zInputFieldValueBase = zFieldValueBase.extend({ + fieldKind: z.literal('input'), + label: z.string(), +}); + +const zModelIdentifier = z.object({ + model_name: z.string().trim().min(1), + base_model: zBaseModel, +}); + +const zImageField = z.object({ + image_name: z.string().trim().min(1), +}); + +const zBoardField = z.object({ + board_id: z.string().trim().min(1), +}); + +const zLatentsField = z.object({ + latents_name: z.string().trim().min(1), + seed: z.number().int().optional(), +}); + +const zConditioningField = z.object({ + conditioning_name: z.string().trim().min(1), +}); + +const zDenoiseMaskField = z.object({ + mask_name: z.string().trim().min(1), + masked_latents_name: z.string().trim().min(1).optional(), +}); + +const zIntegerInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('integer'), + value: z.number().int().optional(), +}); + +const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerCollection'), + value: z.array(z.number().int()).optional(), +}); + +const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IntegerPolymorphic'), + value: z.number().int().optional(), +}); + +const zFloatInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('float'), + value: z.number().optional(), +}); + +const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatCollection'), + value: z.array(z.number()).optional(), +}); + +const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('FloatPolymorphic'), + value: z.number().optional(), +}); + +const zStringInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('string'), + value: z.string().optional(), +}); + +const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringCollection'), + value: z.array(z.string()).optional(), +}); + +const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('StringPolymorphic'), + value: z.string().optional(), +}); + +const zBooleanInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('boolean'), + value: z.boolean().optional(), +}); + +const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanCollection'), + value: z.array(z.boolean()).optional(), +}); + +const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BooleanPolymorphic'), + value: z.boolean().optional(), +}); + +const zEnumInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('enum'), + value: z.string().optional(), +}); + +const zLatentsInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsField'), + value: zLatentsField.optional(), +}); + +const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsCollection'), + value: z.array(zLatentsField).optional(), +}); + +const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LatentsPolymorphic'), + value: z.union([zLatentsField, z.array(zLatentsField)]).optional(), +}); + +const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('DenoiseMaskField'), + value: zDenoiseMaskField.optional(), +}); + +const zConditioningInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningField'), + value: zConditioningField.optional(), +}); + +const zConditioningCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningCollection'), + value: z.array(zConditioningField).optional(), +}); + +const zConditioningPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ConditioningPolymorphic'), + value: z.union([zConditioningField, z.array(zConditioningField)]).optional(), +}); + +const zControlNetModel = zModelIdentifier; + +const zControlField = z.object({ + image: zImageField, + control_model: zControlNetModel, + control_weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + control_mode: z + .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) + .optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); + +const zControlInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlField'), + value: zControlField.optional(), +}); + +const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlPolymorphic'), + value: z.union([zControlField, z.array(zControlField)]).optional(), +}); + +const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlCollection'), + value: z.array(zControlField).optional(), +}); + +const zIPAdapterModel = zModelIdentifier; + +const zIPAdapterField = z.object({ + image: zImageField, + ip_adapter_model: zIPAdapterModel, + weight: z.number(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), +}); + +const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterField'), + value: zIPAdapterField.optional(), +}); + +const zIPAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterPolymorphic'), + value: z.union([zIPAdapterField, z.array(zIPAdapterField)]).optional(), +}); + +const zIPAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterCollection'), + value: z.array(zIPAdapterField).optional(), +}); + +const zT2IAdapterModel = zModelIdentifier; + +const zT2IAdapterField = z.object({ + image: zImageField, + t2i_adapter_model: zT2IAdapterModel, + weight: z.union([z.number(), z.array(z.number())]).optional(), + begin_step_percent: z.number().optional(), + end_step_percent: z.number().optional(), + resize_mode: z + .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) + .optional(), +}); + +const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterField'), + value: zT2IAdapterField.optional(), +}); + +const zT2IAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterPolymorphic'), + value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(), +}); + +const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterCollection'), + value: z.array(zT2IAdapterField).optional(), +}); + +const zModelType = z.enum([ + 'onnx', + 'main', + 'vae', + 'lora', + 'controlnet', + 'embedding', +]); + +const zSubModelType = z.enum([ + 'unet', + 'text_encoder', + 'text_encoder_2', + 'tokenizer', + 'tokenizer_2', + 'vae', + 'vae_decoder', + 'vae_encoder', + 'scheduler', + 'safety_checker', +]); + +const zModelInfo = zModelIdentifier.extend({ + model_type: zModelType, + submodel: zSubModelType.optional(), +}); + +const zLoraInfo = zModelInfo.extend({ + weight: z.number().optional(), +}); + +const zUNetField = z.object({ + unet: zModelInfo, + scheduler: zModelInfo, + loras: z.array(zLoraInfo), +}); + +const zUNetInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('UNetField'), + value: zUNetField.optional(), +}); + +const zClipField = z.object({ + tokenizer: zModelInfo, + text_encoder: zModelInfo, + skipped_layers: z.number(), + loras: z.array(zLoraInfo), +}); + +const zClipInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ClipField'), + value: zClipField.optional(), +}); + +const zVaeField = z.object({ + vae: zModelInfo, +}); + +const zVaeInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('VaeField'), + value: zVaeField.optional(), +}); + +const zImageInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImageField'), + value: zImageField.optional(), +}); + +const zBoardInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('BoardField'), + value: zBoardField.optional(), +}); + +const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImagePolymorphic'), + value: zImageField.optional(), +}); + +const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ImageCollection'), + value: z.array(zImageField).optional(), +}); + +const zMainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MainModelField'), + value: zMainOrOnnxModel.optional(), +}); + +const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SDXLMainModelField'), + value: zMainOrOnnxModel.optional(), +}); + +const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('SDXLRefinerModelField'), + value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model +}); + +const zVaeModelField = zModelIdentifier; + +const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('VaeModelField'), + value: zVaeModelField.optional(), +}); + +const zLoRAModelField = zModelIdentifier; + +const zLoRAModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('LoRAModelField'), + value: zLoRAModelField.optional(), +}); + +const zControlNetModelField = zModelIdentifier; + +const zControlNetModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ControlNetModelField'), + value: zControlNetModelField.optional(), +}); + +const zIPAdapterModelField = zModelIdentifier; + +const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('IPAdapterModelField'), + value: zIPAdapterModelField.optional(), +}); + +const zT2IAdapterModelField = zModelIdentifier; + +const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('T2IAdapterModelField'), + value: zT2IAdapterModelField.optional(), +}); + +const zCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Collection'), + value: z.array(z.any()).optional(), // TODO: should this field ever have a value? +}); + +const zCollectionItemInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('CollectionItem'), + value: z.any().optional(), // TODO: should this field ever have a value? +}); + +const zMetadataItemField = z.object({ + label: z.string(), + value: z.any(), +}); + +const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemField'), + value: zMetadataItemField.optional(), +}); + +const zMetadataItemCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemCollection'), + value: z.array(zMetadataItemField).optional(), +}); + +const zMetadataItemPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataItemPolymorphic'), + value: z.union([zMetadataItemField, z.array(zMetadataItemField)]).optional(), +}); + +const zMetadataField = z.record(z.any()); + +const zMetadataInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataField'), + value: zMetadataField.optional(), +}); + +const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('MetadataCollection'), + value: z.array(zMetadataField).optional(), +}); + +const zColorField = z.object({ + r: z.number().int().min(0).max(255), + g: z.number().int().min(0).max(255), + b: z.number().int().min(0).max(255), + a: z.number().int().min(0).max(255), +}); + +const zColorInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorField'), + value: zColorField.optional(), +}); + +const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorCollection'), + value: z.array(zColorField).optional(), +}); + +const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('ColorPolymorphic'), + value: z.union([zColorField, z.array(zColorField)]).optional(), +}); + +const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Scheduler'), + value: zScheduler.optional(), +}); + +const zAnyInputFieldValue = zInputFieldValueBase.extend({ + type: z.literal('Any'), + value: z.any().optional(), +}); + +const zInputFieldValue = z.discriminatedUnion('type', [ + zAnyInputFieldValue, + zBoardInputFieldValue, + zBooleanCollectionInputFieldValue, + zBooleanInputFieldValue, + zBooleanPolymorphicInputFieldValue, + zClipInputFieldValue, + zCollectionInputFieldValue, + zCollectionItemInputFieldValue, + zColorInputFieldValue, + zColorCollectionInputFieldValue, + zColorPolymorphicInputFieldValue, + zConditioningInputFieldValue, + zConditioningCollectionInputFieldValue, + zConditioningPolymorphicInputFieldValue, + zControlInputFieldValue, + zControlNetModelInputFieldValue, + zControlCollectionInputFieldValue, + zControlPolymorphicInputFieldValue, + zDenoiseMaskInputFieldValue, + zEnumInputFieldValue, + zFloatCollectionInputFieldValue, + zFloatInputFieldValue, + zFloatPolymorphicInputFieldValue, + zImageCollectionInputFieldValue, + zImagePolymorphicInputFieldValue, + zImageInputFieldValue, + zIntegerCollectionInputFieldValue, + zIntegerPolymorphicInputFieldValue, + zIntegerInputFieldValue, + zIPAdapterInputFieldValue, + zIPAdapterModelInputFieldValue, + zIPAdapterCollectionInputFieldValue, + zIPAdapterPolymorphicInputFieldValue, + zLatentsInputFieldValue, + zLatentsCollectionInputFieldValue, + zLatentsPolymorphicInputFieldValue, + zLoRAModelInputFieldValue, + zMainModelInputFieldValue, + zSchedulerInputFieldValue, + zSDXLMainModelInputFieldValue, + zSDXLRefinerModelInputFieldValue, + zStringCollectionInputFieldValue, + zStringPolymorphicInputFieldValue, + zStringInputFieldValue, + zT2IAdapterInputFieldValue, + zT2IAdapterModelInputFieldValue, + zT2IAdapterCollectionInputFieldValue, + zT2IAdapterPolymorphicInputFieldValue, + zUNetInputFieldValue, + zVaeInputFieldValue, + zVaeModelInputFieldValue, + zMetadataItemInputFieldValue, + zMetadataItemCollectionInputFieldValue, + zMetadataItemPolymorphicInputFieldValue, + zMetadataInputFieldValue, + zMetadataCollectionInputFieldValue, +]); + +const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + major !== undefined && + Number.isInteger(Number(major)) && + minor !== undefined && + Number.isInteger(Number(minor)) && + patch !== undefined && + Number.isInteger(Number(patch)) + ); +}); + +const zInvocationNodeData = z.object({ + id: z.string().trim().min(1), + // no easy way to build this dynamically, and we don't want to anyways, because this will be used + // to validate incoming workflows, and we want to allow community nodes. + type: z.string().trim().min(1), + inputs: z.record(zInputFieldValue), + outputs: z.record(zOutputFieldValue), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), + embedWorkflow: z.boolean(), + isIntermediate: z.boolean(), + useCache: z.boolean().default(true), + version: zSemVer.optional(), +}); + +const zNotesNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + label: z.string(), + isOpen: z.boolean(), + notes: z.string(), +}); + +const zPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); + +const zDimension = z.number().gt(0).nullish(); + +const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zPosition, +}); + +const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zPosition, +}); + +const zWorkflowNode = z.discriminatedUnion('type', [ + zWorkflowInvocationNode, + zWorkflowNotesNode, +]); + +const zDefaultWorkflowEdge = z.object({ + source: z.string().trim().min(1), + sourceHandle: z.string().trim().min(1), + target: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), + id: z.string().trim().min(1), + type: z.literal('default'), +}); +const zCollapsedWorkflowEdge = z.object({ + source: z.string().trim().min(1), + target: z.string().trim().min(1), + id: z.string().trim().min(1), + type: z.literal('collapsed'), +}); + +const zWorkflowEdge = z.union([zDefaultWorkflowEdge, zCollapsedWorkflowEdge]); + +const zFieldIdentifier = z.object({ + nodeId: z.string().trim().min(1), + fieldName: z.string().trim().min(1), +}); + +export const zWorkflowV1 = z.object({ + name: z.string().default(''), + author: z.string().default(''), + description: z.string().default(''), + version: z.string().default(''), + contact: z.string().default(''), + tags: z.string().default(''), + notes: z.string().default(''), + nodes: z.array(zWorkflowNode).default([]), + edges: z.array(zWorkflowEdge).default([]), + exposedFields: z.array(zFieldIdentifier).default([]), + meta: z.object({ + version: z.literal('1.0.0'), + }), +}); +export type WorkflowV1 = z.infer; diff --git a/invokeai/frontend/web/src/features/nodes/types/openapi.ts b/invokeai/frontend/web/src/features/nodes/types/openapi.ts new file mode 100644 index 0000000000..0d8ffeb920 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/openapi.ts @@ -0,0 +1,108 @@ +import { OpenAPIV3_1 } from 'openapi-types'; +import { + InputFieldJSONSchemaExtra, + OutputFieldJSONSchemaExtra, +} from 'services/api/types'; + +// Janky customization of OpenAPI Schema :/ + +export type InvocationSchemaExtra = { + output: OpenAPIV3_1.ReferenceObject; // the output of the invocation + title: string; + category?: string; + tags?: string[]; + version: string; + properties: Omit< + NonNullable & + (InputFieldJSONSchemaExtra | OutputFieldJSONSchemaExtra), + 'type' + > & { + type: Omit & { + default: string; + }; + use_cache: Omit & { + default: boolean; + }; + }; +}; + +export type InvocationSchemaType = { + default: string; // the type of the invocation +}; + +export type InvocationBaseSchemaObject = Omit< + OpenAPIV3_1.BaseSchemaObject, + 'title' | 'type' | 'properties' +> & + InvocationSchemaExtra; + +export type InvocationOutputSchemaObject = Omit< + OpenAPIV3_1.SchemaObject, + 'properties' +> & { + properties: OpenAPIV3_1.SchemaObject['properties'] & { + type: Omit & { + default: string; + }; + } & { + class: 'output'; + }; +}; + +export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & + InputFieldJSONSchemaExtra; + +export type OpenAPIV3_1SchemaOrRef = + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject; + +export interface ArraySchemaObject extends InvocationBaseSchemaObject { + type: OpenAPIV3_1.ArraySchemaObjectType; + items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; +} +export interface NonArraySchemaObject extends InvocationBaseSchemaObject { + type?: OpenAPIV3_1.NonArraySchemaObjectType; +} + +export type InvocationSchemaObject = ( + | ArraySchemaObject + | NonArraySchemaObject +) & { class: 'invocation' }; + +export const isSchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); + +export const isArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type === 'array'); + +export const isNonArraySchemaObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.NonArraySchemaObject => + Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); + +export const isRefObject = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined +): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); + +export const isInvocationSchemaObject = ( + obj: + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject + | InvocationSchemaObject +): obj is InvocationSchemaObject => + 'class' in obj && obj.class === 'invocation'; + +export const isInvocationOutputSchemaObject = ( + obj: + | OpenAPIV3_1.ReferenceObject + | OpenAPIV3_1.SchemaObject + | InvocationOutputSchemaObject +): obj is InvocationOutputSchemaObject => + 'class' in obj && obj.class === 'output'; + +export const isInvocationFieldSchema = ( + obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject +): obj is InvocationFieldSchema => !('$ref' in obj); diff --git a/invokeai/frontend/web/src/features/nodes/types/semver.ts b/invokeai/frontend/web/src/features/nodes/types/semver.ts new file mode 100644 index 0000000000..70dc228819 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/semver.ts @@ -0,0 +1,23 @@ +import { z } from 'zod'; + +// Schemas and types for working with semver + +const zVersionInt = z.coerce.number().int().min(0); + +export const zSemVer = z.string().refine((val) => { + const [major, minor, patch] = val.split('.'); + return ( + zVersionInt.safeParse(major).success && + zVersionInt.safeParse(minor).success && + zVersionInt.safeParse(patch).success + ); +}); + +export const zParsedSemver = zSemVer.transform((val) => { + const [major, minor, patch] = val.split('.'); + return { + major: Number(major), + minor: Number(minor), + patch: Number(patch), + }; +}); diff --git a/invokeai/frontend/web/src/features/nodes/types/types.ts b/invokeai/frontend/web/src/features/nodes/types/types.ts deleted file mode 100644 index c55d114dcf..0000000000 --- a/invokeai/frontend/web/src/features/nodes/types/types.ts +++ /dev/null @@ -1,1742 +0,0 @@ -import { $store } from 'app/store/nanostores/store'; -import { - SchedulerParam, - zBaseModel, - zMainModel, - zMainOrOnnxModel, - zOnnxModel, - zSDXLRefinerModel, - zScheduler, -} from 'features/parameters/types/parameterSchemas'; -import i18n from 'i18next'; -import { has, keyBy } from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { RgbaColor } from 'react-colorful'; -import { Node } from 'reactflow'; -import { Graph, _InputField, _OutputField } from 'services/api/types'; -import { - AnyInvocationType, - AnyResult, - ProgressImage, -} from 'services/events/types'; -import { O } from 'ts-toolbelt'; -import { JsonObject } from 'type-fest'; -import { z } from 'zod'; - -export type NonNullableGraph = O.Required; - -export type InvocationTemplate = { - /** - * Unique type of the invocation - */ - type: AnyInvocationType; - /** - * Display name of the invocation - */ - title: string; - /** - * Description of the invocation - */ - description: string; - /** - * Invocation tags - */ - tags: string[]; - /** - * Array of invocation inputs - */ - inputs: Record; - /** - * Array of the invocation outputs - */ - outputs: Record; - /** - * The type of this node's output - */ - outputType: string; // TODO: generate a union of output types - /** - * Whether or not this invocation supports workflows - */ - withWorkflow: boolean; - /** - * The invocation's version. - */ - version?: string; - /** - * Whether or not this node should use the cache - */ - useCache: boolean; -}; - -export type FieldUIConfig = { - title: string; - description: string; - color: string; -}; - -// TODO: Get this from the OpenAPI schema? may be tricky... -export const zFieldType = z.enum([ - 'Any', - 'BoardField', - 'boolean', - 'BooleanCollection', - 'BooleanPolymorphic', - 'ClipField', - 'Collection', - 'CollectionItem', - 'ColorCollection', - 'ColorField', - 'ColorPolymorphic', - 'ConditioningCollection', - 'ConditioningField', - 'ConditioningPolymorphic', - 'ControlCollection', - 'ControlField', - 'ControlNetModelField', - 'ControlPolymorphic', - 'DenoiseMaskField', - 'enum', - 'float', - 'FloatCollection', - 'FloatPolymorphic', - 'ImageCollection', - 'ImageField', - 'ImagePolymorphic', - 'integer', - 'IntegerCollection', - 'IntegerPolymorphic', - 'IPAdapterCollection', - 'IPAdapterField', - 'IPAdapterModelField', - 'IPAdapterPolymorphic', - 'LatentsCollection', - 'LatentsField', - 'LatentsPolymorphic', - 'LoRAModelField', - 'MainModelField', - 'MetadataField', - 'MetadataCollection', - 'MetadataItemField', - 'MetadataItemCollection', - 'MetadataItemPolymorphic', - 'ONNXModelField', - 'Scheduler', - 'SDXLMainModelField', - 'SDXLRefinerModelField', - 'string', - 'StringCollection', - 'StringPolymorphic', - 'T2IAdapterCollection', - 'T2IAdapterField', - 'T2IAdapterModelField', - 'T2IAdapterPolymorphic', - 'UNetField', - 'VaeField', - 'VaeModelField', -]); - -export type FieldType = z.infer; -export type FieldTypeMap = { [key in FieldType]?: FieldType }; -export type FieldTypeMapWithNumber = { - [key in FieldType | 'number']?: FieldType; -}; - -export const zReservedFieldType = z.enum([ - 'WorkflowField', - 'IsIntermediate', - 'MetadataField', -]); - -export type ReservedFieldType = z.infer; - -export const isFieldType = (value: unknown): value is FieldType => - zFieldType.safeParse(value).success || - zReservedFieldType.safeParse(value).success; - -/** - * Indicates the kind of input(s) this field may have. - */ -export const zInputKind = z.enum(['connection', 'direct', 'any']); -export type InputKind = z.infer; - -export const zFieldValueBase = z.object({ - id: z.string().trim().min(1), - name: z.string().trim().min(1), - type: zFieldType, -}); -export type FieldValueBase = z.infer; - -/** - * An output field is persisted across as part of the user's local state. - * - * An output field has two properties: - * - `id` a unique identifier - * - `name` the name of the field, which comes from the python dataclass - */ - -export const zOutputFieldValue = zFieldValueBase.extend({ - fieldKind: z.literal('output'), -}); -export type OutputFieldValue = z.infer; - -/** - * An output field template is generated on each page load from the OpenAPI schema. - * - * The template provides the output field's name, type, title, and description. - */ -export type OutputFieldTemplate = { - fieldKind: 'output'; - name: string; - type: FieldType; - title: string; - description: string; -} & _OutputField; - -export const zInputFieldValueBase = zFieldValueBase.extend({ - fieldKind: z.literal('input'), - label: z.string(), -}); -export type InputFieldValueBase = z.infer; - -export const zModelIdentifier = z.object({ - model_name: z.string().trim().min(1), - base_model: zBaseModel, -}); - -export const zImageField = z.object({ - image_name: z.string().trim().min(1), -}); -export type ImageField = z.infer; - -export const zBoardField = z.object({ - board_id: z.string().trim().min(1), -}); -export type BoardField = z.infer; - -export const zLatentsField = z.object({ - latents_name: z.string().trim().min(1), - seed: z.number().int().optional(), -}); -export type LatentsField = z.infer; - -export const zConditioningField = z.object({ - conditioning_name: z.string().trim().min(1), -}); -export type ConditioningField = z.infer; - -export const zDenoiseMaskField = z.object({ - mask_name: z.string().trim().min(1), - masked_latents_name: z.string().trim().min(1).optional(), -}); -export type DenoiseMaskFieldValue = z.infer; - -export const zIntegerInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('integer'), - value: z.number().int().optional(), -}); -export type IntegerInputFieldValue = z.infer; - -export const zIntegerCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IntegerCollection'), - value: z.array(z.number().int()).optional(), -}); -export type IntegerCollectionInputFieldValue = z.infer< - typeof zIntegerCollectionInputFieldValue ->; - -export const zIntegerPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IntegerPolymorphic'), - value: z.number().int().optional(), -}); -export type IntegerPolymorphicInputFieldValue = z.infer< - typeof zIntegerPolymorphicInputFieldValue ->; - -export const zFloatInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('float'), - value: z.number().optional(), -}); -export type FloatInputFieldValue = z.infer; - -export const zFloatCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('FloatCollection'), - value: z.array(z.number()).optional(), -}); -export type FloatCollectionInputFieldValue = z.infer< - typeof zFloatCollectionInputFieldValue ->; - -export const zFloatPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('FloatPolymorphic'), - value: z.number().optional(), -}); -export type FloatPolymorphicInputFieldValue = z.infer< - typeof zFloatPolymorphicInputFieldValue ->; - -export const zStringInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('string'), - value: z.string().optional(), -}); -export type StringInputFieldValue = z.infer; - -export const zStringCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('StringCollection'), - value: z.array(z.string()).optional(), -}); -export type StringCollectionInputFieldValue = z.infer< - typeof zStringCollectionInputFieldValue ->; - -export const zStringPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('StringPolymorphic'), - value: z.string().optional(), -}); -export type StringPolymorphicInputFieldValue = z.infer< - typeof zStringPolymorphicInputFieldValue ->; - -export const zBooleanInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('boolean'), - value: z.boolean().optional(), -}); -export type BooleanInputFieldValue = z.infer; - -export const zBooleanCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BooleanCollection'), - value: z.array(z.boolean()).optional(), -}); -export type BooleanCollectionInputFieldValue = z.infer< - typeof zBooleanCollectionInputFieldValue ->; - -export const zBooleanPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BooleanPolymorphic'), - value: z.boolean().optional(), -}); -export type BooleanPolymorphicInputFieldValue = z.infer< - typeof zBooleanPolymorphicInputFieldValue ->; - -export const zEnumInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('enum'), - value: z.string().optional(), -}); -export type EnumInputFieldValue = z.infer; - -export const zLatentsInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsField'), - value: zLatentsField.optional(), -}); -export type LatentsInputFieldValue = z.infer; - -export const zLatentsCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsCollection'), - value: z.array(zLatentsField).optional(), -}); -export type LatentsCollectionInputFieldValue = z.infer< - typeof zLatentsCollectionInputFieldValue ->; - -export const zLatentsPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LatentsPolymorphic'), - value: z.union([zLatentsField, z.array(zLatentsField)]).optional(), -}); -export type LatentsPolymorphicInputFieldValue = z.infer< - typeof zLatentsPolymorphicInputFieldValue ->; - -export const zDenoiseMaskInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('DenoiseMaskField'), - value: zDenoiseMaskField.optional(), -}); -export type DenoiseMaskInputFieldValue = z.infer< - typeof zDenoiseMaskInputFieldValue ->; - -export const zConditioningInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ConditioningField'), - value: zConditioningField.optional(), -}); -export type ConditioningInputFieldValue = z.infer< - typeof zConditioningInputFieldValue ->; - -export const zConditioningCollectionInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('ConditioningCollection'), - value: z.array(zConditioningField).optional(), - }); -export type ConditioningCollectionInputFieldValue = z.infer< - typeof zConditioningCollectionInputFieldValue ->; - -export const zConditioningPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('ConditioningPolymorphic'), - value: z - .union([zConditioningField, z.array(zConditioningField)]) - .optional(), - }); -export type ConditioningPolymorphicInputFieldValue = z.infer< - typeof zConditioningPolymorphicInputFieldValue ->; - -export const zControlNetModel = zModelIdentifier; -export type ControlNetModel = z.infer; - -export const zControlField = z.object({ - image: zImageField, - control_model: zControlNetModel, - control_weight: z.union([z.number(), z.array(z.number())]).optional(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), - control_mode: z - .enum(['balanced', 'more_prompt', 'more_control', 'unbalanced']) - .optional(), - resize_mode: z - .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) - .optional(), -}); -export type ControlField = z.infer; - -export const zControlInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlField'), - value: zControlField.optional(), -}); -export type ControlInputFieldValue = z.infer; - -export const zControlPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlPolymorphic'), - value: z.union([zControlField, z.array(zControlField)]).optional(), -}); -export type ControlPolymorphicInputFieldValue = z.infer< - typeof zControlPolymorphicInputFieldValue ->; - -export const zControlCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlCollection'), - value: z.array(zControlField).optional(), -}); -export type ControlCollectionInputFieldValue = z.infer< - typeof zControlCollectionInputFieldValue ->; - -export const zIPAdapterModel = zModelIdentifier; -export type IPAdapterModel = z.infer; - -export const zIPAdapterField = z.object({ - image: zImageField, - ip_adapter_model: zIPAdapterModel, - weight: z.number(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), -}); -export type IPAdapterField = z.infer; - -export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterField'), - value: zIPAdapterField.optional(), -}); -export type IPAdapterInputFieldValue = z.infer< - typeof zIPAdapterInputFieldValue ->; - -export const zIPAdapterPolymorphicInputFieldValue = zInputFieldValueBase.extend( - { - type: z.literal('IPAdapterPolymorphic'), - value: z.union([zIPAdapterField, z.array(zIPAdapterField)]).optional(), - } -); -export type IPAdapterPolymorphicInputFieldValue = z.infer< - typeof zT2IAdapterPolymorphicInputFieldValue ->; - -export const zIPAdapterCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterCollection'), - value: z.array(zIPAdapterField).optional(), -}); -export type IPAdapterCollectionInputFieldValue = z.infer< - typeof zIPAdapterCollectionInputFieldValue ->; - -export const zT2IAdapterModel = zModelIdentifier; -export type T2IAdapterModel = z.infer; - -export const zT2IAdapterField = z.object({ - image: zImageField, - t2i_adapter_model: zT2IAdapterModel, - weight: z.union([z.number(), z.array(z.number())]).optional(), - begin_step_percent: z.number().optional(), - end_step_percent: z.number().optional(), - resize_mode: z - .enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple']) - .optional(), -}); -export type T2IAdapterField = z.infer; - -export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterField'), - value: zT2IAdapterField.optional(), -}); -export type T2IAdapterInputFieldValue = z.infer< - typeof zT2IAdapterInputFieldValue ->; - -export const zT2IAdapterPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterPolymorphic'), - value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(), - }); -export type T2IAdapterPolymorphicInputFieldValue = z.infer< - typeof zT2IAdapterPolymorphicInputFieldValue ->; - -export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend( - { - type: z.literal('T2IAdapterCollection'), - value: z.array(zT2IAdapterField).optional(), - } -); -export type T2IAdapterCollectionInputFieldValue = z.infer< - typeof zT2IAdapterCollectionInputFieldValue ->; - -export const zModelType = z.enum([ - 'onnx', - 'main', - 'vae', - 'lora', - 'controlnet', - 'embedding', -]); -export type ModelType = z.infer; - -export const zSubModelType = z.enum([ - 'unet', - 'text_encoder', - 'text_encoder_2', - 'tokenizer', - 'tokenizer_2', - 'vae', - 'vae_decoder', - 'vae_encoder', - 'scheduler', - 'safety_checker', -]); -export type SubModelType = z.infer; - -export const zModelInfo = zModelIdentifier.extend({ - model_type: zModelType, - submodel: zSubModelType.optional(), -}); -export type ModelInfo = z.infer; - -export const zLoraInfo = zModelInfo.extend({ - weight: z.number().optional(), -}); -export type LoraInfo = z.infer; - -export const zUNetField = z.object({ - unet: zModelInfo, - scheduler: zModelInfo, - loras: z.array(zLoraInfo), -}); -export type UNetField = z.infer; - -export const zUNetInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('UNetField'), - value: zUNetField.optional(), -}); -export type UNetInputFieldValue = z.infer; - -export const zClipField = z.object({ - tokenizer: zModelInfo, - text_encoder: zModelInfo, - skipped_layers: z.number(), - loras: z.array(zLoraInfo), -}); -export type ClipField = z.infer; - -export const zClipInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ClipField'), - value: zClipField.optional(), -}); -export type ClipInputFieldValue = z.infer; - -export const zVaeField = z.object({ - vae: zModelInfo, -}); -export type VaeField = z.infer; - -export const zVaeInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('VaeField'), - value: zVaeField.optional(), -}); -export type VaeInputFieldValue = z.infer; - -export const zImageInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImageField'), - value: zImageField.optional(), -}); -export type ImageInputFieldValue = z.infer; - -export const zBoardInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('BoardField'), - value: zBoardField.optional(), -}); -export type BoardInputFieldValue = z.infer; - -export const zImagePolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImagePolymorphic'), - value: zImageField.optional(), -}); -export type ImagePolymorphicInputFieldValue = z.infer< - typeof zImagePolymorphicInputFieldValue ->; - -export const zImageCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ImageCollection'), - value: z.array(zImageField).optional(), -}); -export type ImageCollectionInputFieldValue = z.infer< - typeof zImageCollectionInputFieldValue ->; - -export const zMainModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MainModelField'), - value: zMainOrOnnxModel.optional(), -}); -export type MainModelInputFieldValue = z.infer< - typeof zMainModelInputFieldValue ->; - -export const zSDXLMainModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('SDXLMainModelField'), - value: zMainOrOnnxModel.optional(), -}); -export type SDXLMainModelInputFieldValue = z.infer< - typeof zSDXLMainModelInputFieldValue ->; - -export const zSDXLRefinerModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('SDXLRefinerModelField'), - value: zMainOrOnnxModel.optional(), // TODO: should narrow this down to a refiner model -}); -export type SDXLRefinerModelInputFieldValue = z.infer< - typeof zSDXLRefinerModelInputFieldValue ->; - -export const zVaeModelField = zModelIdentifier; - -export const zVaeModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('VaeModelField'), - value: zVaeModelField.optional(), -}); -export type VaeModelInputFieldValue = z.infer; - -export const zLoRAModelField = zModelIdentifier; -export type LoRAModelField = z.infer; - -export const zLoRAModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('LoRAModelField'), - value: zLoRAModelField.optional(), -}); -export type LoRAModelInputFieldValue = z.infer< - typeof zLoRAModelInputFieldValue ->; - -export const zControlNetModelField = zModelIdentifier; -export type ControlNetModelField = z.infer; - -export const zControlNetModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ControlNetModelField'), - value: zControlNetModelField.optional(), -}); -export type ControlNetModelInputFieldValue = z.infer< - typeof zControlNetModelInputFieldValue ->; - -export const zIPAdapterModelField = zModelIdentifier; -export type IPAdapterModelField = z.infer; - -export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('IPAdapterModelField'), - value: zIPAdapterModelField.optional(), -}); -export type IPAdapterModelInputFieldValue = z.infer< - typeof zIPAdapterModelInputFieldValue ->; - -export const zT2IAdapterModelField = zModelIdentifier; -export type T2IAdapterModelField = z.infer; - -export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('T2IAdapterModelField'), - value: zT2IAdapterModelField.optional(), -}); -export type T2IAdapterModelInputFieldValue = z.infer< - typeof zT2IAdapterModelInputFieldValue ->; - -export const zCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Collection'), - value: z.array(z.any()).optional(), // TODO: should this field ever have a value? -}); -export type CollectionInputFieldValue = z.infer< - typeof zCollectionInputFieldValue ->; - -export const zCollectionItemInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('CollectionItem'), - value: z.any().optional(), // TODO: should this field ever have a value? -}); -export type CollectionItemInputFieldValue = z.infer< - typeof zCollectionItemInputFieldValue ->; - -export const zMetadataItemField = z.object({ - label: z.string(), - value: z.any(), -}); -export type MetadataItemField = z.infer; - -export const zMetadataItemInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataItemField'), - value: zMetadataItemField.optional(), -}); -export type MetadataItemInputFieldValue = z.infer< - typeof zMetadataItemInputFieldValue ->; - -export const zMetadataItemCollectionInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('MetadataItemCollection'), - value: z.array(zMetadataItemField).optional(), - }); -export type MetadataItemCollectionInputFieldValue = z.infer< - typeof zMetadataItemCollectionInputFieldValue ->; - -export const zMetadataItemPolymorphicInputFieldValue = - zInputFieldValueBase.extend({ - type: z.literal('MetadataItemPolymorphic'), - value: z - .union([zMetadataItemField, z.array(zMetadataItemField)]) - .optional(), - }); -export type MetadataItemPolymorphicInputFieldValue = z.infer< - typeof zMetadataItemPolymorphicInputFieldValue ->; - -export const zMetadataField = z.record(z.any()); -export type MetadataField = z.infer; - -export const zMetadataInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataField'), - value: zMetadataField.optional(), -}); -export type MetadataInputFieldValue = z.infer; - -export const zMetadataCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('MetadataCollection'), - value: z.array(zMetadataField).optional(), -}); -export type MetadataCollectionInputFieldValue = z.infer< - typeof zMetadataCollectionInputFieldValue ->; - -export const zColorField = z.object({ - r: z.number().int().min(0).max(255), - g: z.number().int().min(0).max(255), - b: z.number().int().min(0).max(255), - a: z.number().int().min(0).max(255), -}); -export type ColorField = z.infer; - -export const zColorInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorField'), - value: zColorField.optional(), -}); -export type ColorInputFieldValue = z.infer; - -export const zColorCollectionInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorCollection'), - value: z.array(zColorField).optional(), -}); -export type ColorCollectionInputFieldValue = z.infer< - typeof zColorCollectionInputFieldValue ->; - -export const zColorPolymorphicInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('ColorPolymorphic'), - value: z.union([zColorField, z.array(zColorField)]).optional(), -}); -export type ColorPolymorphicInputFieldValue = z.infer< - typeof zColorPolymorphicInputFieldValue ->; - -export const zSchedulerInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Scheduler'), - value: zScheduler.optional(), -}); -export type SchedulerInputFieldValue = z.infer< - typeof zSchedulerInputFieldValue ->; - -export const zAnyInputFieldValue = zInputFieldValueBase.extend({ - type: z.literal('Any'), - value: z.any().optional(), -}); - -export const zInputFieldValue = z.discriminatedUnion('type', [ - zAnyInputFieldValue, - zBoardInputFieldValue, - zBooleanCollectionInputFieldValue, - zBooleanInputFieldValue, - zBooleanPolymorphicInputFieldValue, - zClipInputFieldValue, - zCollectionInputFieldValue, - zCollectionItemInputFieldValue, - zColorInputFieldValue, - zColorCollectionInputFieldValue, - zColorPolymorphicInputFieldValue, - zConditioningInputFieldValue, - zConditioningCollectionInputFieldValue, - zConditioningPolymorphicInputFieldValue, - zControlInputFieldValue, - zControlNetModelInputFieldValue, - zControlCollectionInputFieldValue, - zControlPolymorphicInputFieldValue, - zDenoiseMaskInputFieldValue, - zEnumInputFieldValue, - zFloatCollectionInputFieldValue, - zFloatInputFieldValue, - zFloatPolymorphicInputFieldValue, - zImageCollectionInputFieldValue, - zImagePolymorphicInputFieldValue, - zImageInputFieldValue, - zIntegerCollectionInputFieldValue, - zIntegerPolymorphicInputFieldValue, - zIntegerInputFieldValue, - zIPAdapterInputFieldValue, - zIPAdapterModelInputFieldValue, - zIPAdapterCollectionInputFieldValue, - zIPAdapterPolymorphicInputFieldValue, - zLatentsInputFieldValue, - zLatentsCollectionInputFieldValue, - zLatentsPolymorphicInputFieldValue, - zLoRAModelInputFieldValue, - zMainModelInputFieldValue, - zSchedulerInputFieldValue, - zSDXLMainModelInputFieldValue, - zSDXLRefinerModelInputFieldValue, - zStringCollectionInputFieldValue, - zStringPolymorphicInputFieldValue, - zStringInputFieldValue, - zT2IAdapterInputFieldValue, - zT2IAdapterModelInputFieldValue, - zT2IAdapterCollectionInputFieldValue, - zT2IAdapterPolymorphicInputFieldValue, - zUNetInputFieldValue, - zVaeInputFieldValue, - zVaeModelInputFieldValue, - zMetadataItemInputFieldValue, - zMetadataItemCollectionInputFieldValue, - zMetadataItemPolymorphicInputFieldValue, - zMetadataInputFieldValue, - zMetadataCollectionInputFieldValue, -]); - -export type InputFieldValue = z.infer; - -export type InputFieldTemplateBase = { - name: string; - title: string; - description: string; - required: boolean; - fieldKind: 'input'; -} & _InputField; - -export type AnyInputFieldTemplate = InputFieldTemplateBase & { - type: 'Any'; - default: undefined; -}; - -export type IntegerInputFieldTemplate = InputFieldTemplateBase & { - type: 'integer'; - default: number; - multipleOf?: number; - maximum?: number; - exclusiveMaximum?: number; - minimum?: number; - exclusiveMinimum?: number; -}; - -export type IntegerCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'IntegerCollection'; - default: number[]; - item_default?: number; -}; - -export type IntegerPolymorphicInputFieldTemplate = Omit< - IntegerInputFieldTemplate, - 'type' -> & { - type: 'IntegerPolymorphic'; -}; - -export type FloatInputFieldTemplate = InputFieldTemplateBase & { - type: 'float'; - default: number; - multipleOf?: number; - maximum?: number; - exclusiveMaximum?: number; - minimum?: number; - exclusiveMinimum?: number; -}; - -export type FloatCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'FloatCollection'; - default: number[]; - item_default?: number; -}; - -export type FloatPolymorphicInputFieldTemplate = Omit< - FloatInputFieldTemplate, - 'type' -> & { - type: 'FloatPolymorphic'; -}; - -export type StringInputFieldTemplate = InputFieldTemplateBase & { - type: 'string'; - default: string; - maxLength?: number; - minLength?: number; - pattern?: string; -}; - -export type StringCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'StringCollection'; - default: string[]; - item_default?: string; -}; - -export type StringPolymorphicInputFieldTemplate = Omit< - StringInputFieldTemplate, - 'type' -> & { - type: 'StringPolymorphic'; -}; - -export type BooleanInputFieldTemplate = InputFieldTemplateBase & { - default: boolean; - type: 'boolean'; -}; - -export type BooleanCollectionInputFieldTemplate = InputFieldTemplateBase & { - type: 'BooleanCollection'; - default: boolean[]; - item_default?: boolean; -}; - -export type BooleanPolymorphicInputFieldTemplate = Omit< - BooleanInputFieldTemplate, - 'type' -> & { - type: 'BooleanPolymorphic'; -}; - -export type BoardInputFieldTemplate = InputFieldTemplateBase & { - default: BoardField; - type: 'BoardField'; -}; - -export type ImageInputFieldTemplate = InputFieldTemplateBase & { - default: ImageField; - type: 'ImageField'; -}; - -export type ImageCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: ImageField[]; - type: 'ImageCollection'; - item_default?: ImageField; -}; - -export type ImagePolymorphicInputFieldTemplate = Omit< - ImageInputFieldTemplate, - 'type' -> & { - type: 'ImagePolymorphic'; -}; - -export type DenoiseMaskInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'DenoiseMaskField'; -}; - -export type LatentsInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField; - type: 'LatentsField'; -}; - -export type LatentsCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField[]; - type: 'LatentsCollection'; - item_default?: LatentsField; -}; - -export type LatentsPolymorphicInputFieldTemplate = InputFieldTemplateBase & { - default: LatentsField; - type: 'LatentsPolymorphic'; -}; - -export type ConditioningInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ConditioningField'; -}; - -export type ConditioningCollectionInputFieldTemplate = - InputFieldTemplateBase & { - default: ConditioningField[]; - type: 'ConditioningCollection'; - item_default?: ConditioningField; - }; - -export type ConditioningPolymorphicInputFieldTemplate = Omit< - ConditioningInputFieldTemplate, - 'type' -> & { - type: 'ConditioningPolymorphic'; -}; - -export type UNetInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'UNetField'; -}; - -export type MetadataItemFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemField'; -}; - -export type ClipInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ClipField'; -}; - -export type VaeInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'VaeField'; -}; - -export type ControlInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ControlField'; -}; - -export type ControlCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'ControlCollection'; - item_default?: ControlField; -}; - -export type ControlPolymorphicInputFieldTemplate = Omit< - ControlInputFieldTemplate, - 'type' -> & { - type: 'ControlPolymorphic'; -}; - -export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'IPAdapterField'; -}; - -export type IPAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'IPAdapterCollection'; - item_default?: IPAdapterField; -}; - -export type IPAdapterPolymorphicInputFieldTemplate = Omit< - IPAdapterInputFieldTemplate, - 'type' -> & { - type: 'IPAdapterPolymorphic'; -}; - -export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'T2IAdapterField'; -}; - -export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'T2IAdapterCollection'; - item_default?: T2IAdapterField; -}; - -export type T2IAdapterPolymorphicInputFieldTemplate = Omit< - T2IAdapterInputFieldTemplate, - 'type' -> & { - type: 'T2IAdapterPolymorphic'; -}; - -export type EnumInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'enum'; - options: string[]; - labels?: { [key: string]: string }; -}; - -export type MainModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MainModelField'; -}; - -export type SDXLMainModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'SDXLMainModelField'; -}; - -export type SDXLRefinerModelInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'SDXLRefinerModelField'; -}; - -export type VaeModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'VaeModelField'; -}; - -export type LoRAModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'LoRAModelField'; -}; - -export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'ControlNetModelField'; -}; - -export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'IPAdapterModelField'; -}; - -export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & { - default: string; - type: 'T2IAdapterModelField'; -}; - -export type CollectionInputFieldTemplate = InputFieldTemplateBase & { - default: []; - type: 'Collection'; -}; - -export type CollectionItemInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'CollectionItem'; -}; - -export type ColorInputFieldTemplate = InputFieldTemplateBase & { - default: RgbaColor; - type: 'ColorField'; -}; - -export type ColorPolymorphicInputFieldTemplate = Omit< - ColorInputFieldTemplate, - 'type' -> & { - type: 'ColorPolymorphic'; -}; - -export type ColorCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: []; - type: 'ColorCollection'; -}; - -export type SchedulerInputFieldTemplate = InputFieldTemplateBase & { - default: SchedulerParam; - type: 'Scheduler'; -}; - -export type WorkflowInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'WorkflowField'; -}; - -export type MetadataItemInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemField'; -}; - -export type MetadataItemCollectionInputFieldTemplate = - InputFieldTemplateBase & { - default: undefined; - type: 'MetadataItemCollection'; - }; - -export type MetadataItemPolymorphicInputFieldTemplate = Omit< - MetadataItemInputFieldTemplate, - 'type' -> & { - type: 'MetadataItemPolymorphic'; -}; - -export type MetadataInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataField'; -}; - -export type MetadataCollectionInputFieldTemplate = InputFieldTemplateBase & { - default: undefined; - type: 'MetadataCollection'; -}; - -/** - * An input field template is generated on each page load from the OpenAPI schema. - * - * The template provides the field type and other field metadata (e.g. title, description, - * maximum length, pattern to match, etc). - */ -export type InputFieldTemplate = - | AnyInputFieldTemplate - | BoardInputFieldTemplate - | BooleanCollectionInputFieldTemplate - | BooleanPolymorphicInputFieldTemplate - | BooleanInputFieldTemplate - | ClipInputFieldTemplate - | CollectionInputFieldTemplate - | CollectionItemInputFieldTemplate - | ColorInputFieldTemplate - | ColorCollectionInputFieldTemplate - | ColorPolymorphicInputFieldTemplate - | ConditioningInputFieldTemplate - | ConditioningCollectionInputFieldTemplate - | ConditioningPolymorphicInputFieldTemplate - | ControlInputFieldTemplate - | ControlCollectionInputFieldTemplate - | ControlNetModelInputFieldTemplate - | ControlPolymorphicInputFieldTemplate - | DenoiseMaskInputFieldTemplate - | EnumInputFieldTemplate - | FloatCollectionInputFieldTemplate - | FloatInputFieldTemplate - | FloatPolymorphicInputFieldTemplate - | ImageCollectionInputFieldTemplate - | ImagePolymorphicInputFieldTemplate - | ImageInputFieldTemplate - | IntegerCollectionInputFieldTemplate - | IntegerPolymorphicInputFieldTemplate - | IntegerInputFieldTemplate - | IPAdapterInputFieldTemplate - | IPAdapterCollectionInputFieldTemplate - | IPAdapterModelInputFieldTemplate - | IPAdapterPolymorphicInputFieldTemplate - | LatentsInputFieldTemplate - | LatentsCollectionInputFieldTemplate - | LatentsPolymorphicInputFieldTemplate - | LoRAModelInputFieldTemplate - | MainModelInputFieldTemplate - | SchedulerInputFieldTemplate - | SDXLMainModelInputFieldTemplate - | SDXLRefinerModelInputFieldTemplate - | StringCollectionInputFieldTemplate - | StringPolymorphicInputFieldTemplate - | StringInputFieldTemplate - | T2IAdapterInputFieldTemplate - | T2IAdapterCollectionInputFieldTemplate - | T2IAdapterModelInputFieldTemplate - | T2IAdapterPolymorphicInputFieldTemplate - | UNetInputFieldTemplate - | VaeInputFieldTemplate - | VaeModelInputFieldTemplate - | MetadataItemInputFieldTemplate - | MetadataItemCollectionInputFieldTemplate - | MetadataInputFieldTemplate - | MetadataItemPolymorphicInputFieldTemplate - | MetadataCollectionInputFieldTemplate; - -export const isInputFieldValue = ( - field?: InputFieldValue | OutputFieldValue -): field is InputFieldValue => Boolean(field && field.fieldKind === 'input'); - -export const isInputFieldTemplate = ( - fieldTemplate?: InputFieldTemplate | OutputFieldTemplate -): fieldTemplate is InputFieldTemplate => - Boolean(fieldTemplate && fieldTemplate.fieldKind === 'input'); - -/** - * JANKY CUSTOMISATION OF OpenAPI SCHEMA TYPES - */ - -export type TypeHints = { - [fieldName: string]: FieldType; -}; - -export type InvocationSchemaExtra = { - output: OpenAPIV3_1.ReferenceObject; // the output of the invocation - title: string; - category?: string; - tags?: string[]; - version?: string; - properties: Omit< - NonNullable & - (_InputField | _OutputField), - 'type' - > & { - type: Omit & { - default: AnyInvocationType; - }; - use_cache: Omit & { - default: boolean; - }; - }; -}; - -export type InvocationSchemaType = { - default: string; // the type of the invocation -}; - -export type InvocationBaseSchemaObject = Omit< - OpenAPIV3_1.BaseSchemaObject, - 'title' | 'type' | 'properties' -> & - InvocationSchemaExtra; - -export type InvocationOutputSchemaObject = Omit< - OpenAPIV3_1.SchemaObject, - 'properties' -> & { - properties: OpenAPIV3_1.SchemaObject['properties'] & { - type: Omit & { - default: string; - }; - } & { - class: 'output'; - }; -}; - -export type InvocationFieldSchema = OpenAPIV3_1.SchemaObject & _InputField; - -export type OpenAPIV3_1SchemaOrRef = - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject; - -export interface ArraySchemaObject extends InvocationBaseSchemaObject { - type: OpenAPIV3_1.ArraySchemaObjectType; - items: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject; -} -export interface NonArraySchemaObject extends InvocationBaseSchemaObject { - type?: OpenAPIV3_1.NonArraySchemaObjectType; -} - -export type InvocationSchemaObject = ( - | ArraySchemaObject - | NonArraySchemaObject -) & { class: 'invocation' }; - -export const isSchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.SchemaObject => Boolean(obj && !('$ref' in obj)); - -export const isArraySchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.ArraySchemaObject => - Boolean(obj && !('$ref' in obj) && obj.type === 'array'); - -export const isNonArraySchemaObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.NonArraySchemaObject => - Boolean(obj && !('$ref' in obj) && obj.type !== 'array'); - -export const isRefObject = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject | undefined -): obj is OpenAPIV3_1.ReferenceObject => Boolean(obj && '$ref' in obj); - -export const isInvocationSchemaObject = ( - obj: - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject - | InvocationSchemaObject -): obj is InvocationSchemaObject => - 'class' in obj && obj.class === 'invocation'; - -export const isInvocationOutputSchemaObject = ( - obj: - | OpenAPIV3_1.ReferenceObject - | OpenAPIV3_1.SchemaObject - | InvocationOutputSchemaObject -): obj is InvocationOutputSchemaObject => - 'class' in obj && obj.class === 'output'; - -export const isInvocationFieldSchema = ( - obj: OpenAPIV3_1.ReferenceObject | OpenAPIV3_1.SchemaObject -): obj is InvocationFieldSchema => !('$ref' in obj); - -export type InvocationEdgeExtra = { type: 'default' | 'collapsed' }; - -export const zLoRAMetadataItem = z.object({ - lora: zLoRAModelField.deepPartial(), - weight: z.number(), -}); - -export type LoRAMetadataItem = z.infer; - -const zControlNetMetadataItem = zControlField.deepPartial(); - -export type ControlNetMetadataItem = z.infer; - -const zIPAdapterMetadataItem = zIPAdapterField.deepPartial(); - -export type IPAdapterMetadataItem = z.infer; - -const zT2IAdapterMetadataItem = zT2IAdapterField.deepPartial(); - -export type T2IAdapterMetadataItem = z.infer; - -export const zCoreMetadata = z - .object({ - app_version: z.string().nullish().catch(null), - generation_mode: z.string().nullish().catch(null), - created_by: z.string().nullish().catch(null), - positive_prompt: z.string().nullish().catch(null), - negative_prompt: z.string().nullish().catch(null), - width: z.number().int().nullish().catch(null), - height: z.number().int().nullish().catch(null), - seed: z.number().int().nullish().catch(null), - rand_device: z.string().nullish().catch(null), - cfg_scale: z.number().nullish().catch(null), - steps: z.number().int().nullish().catch(null), - scheduler: z.string().nullish().catch(null), - clip_skip: z.number().int().nullish().catch(null), - model: z - .union([zMainModel.deepPartial(), zOnnxModel.deepPartial()]) - .nullish() - .catch(null), - controlnets: z.array(zControlNetMetadataItem).nullish().catch(null), - ipAdapters: z.array(zIPAdapterMetadataItem).nullish().catch(null), - t2iAdapters: z.array(zT2IAdapterMetadataItem).nullish().catch(null), - loras: z.array(zLoRAMetadataItem).nullish().catch(null), - vae: zVaeModelField.nullish().catch(null), - strength: z.number().nullish().catch(null), - hrf_enabled: z.boolean().nullish().catch(null), - hrf_strength: z.number().nullish().catch(null), - hrf_method: z.string().nullish().catch(null), - init_image: z.string().nullish().catch(null), - positive_style_prompt: z.string().nullish().catch(null), - negative_style_prompt: z.string().nullish().catch(null), - refiner_model: zSDXLRefinerModel.deepPartial().nullish().catch(null), - refiner_cfg_scale: z.number().nullish().catch(null), - refiner_steps: z.number().int().nullish().catch(null), - refiner_scheduler: z.string().nullish().catch(null), - refiner_positive_aesthetic_score: z.number().nullish().catch(null), - refiner_negative_aesthetic_score: z.number().nullish().catch(null), - refiner_start: z.number().nullish().catch(null), - }) - .passthrough(); - -export type CoreMetadata = z.infer; - -export const zSemVer = z.string().refine((val) => { - const [major, minor, patch] = val.split('.'); - return ( - major !== undefined && - Number.isInteger(Number(major)) && - minor !== undefined && - Number.isInteger(Number(minor)) && - patch !== undefined && - Number.isInteger(Number(patch)) - ); -}); - -export const zParsedSemver = zSemVer.transform((val) => { - const [major, minor, patch] = val.split('.'); - return { - major: Number(major), - minor: Number(minor), - patch: Number(patch), - }; -}); - -export type SemVer = z.infer; - -export const zInvocationNodeData = z.object({ - id: z.string().trim().min(1), - // no easy way to build this dynamically, and we don't want to anyways, because this will be used - // to validate incoming workflows, and we want to allow community nodes. - type: z.string().trim().min(1), - inputs: z.record(zInputFieldValue), - outputs: z.record(zOutputFieldValue), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), - embedWorkflow: z.boolean(), - isIntermediate: z.boolean(), - useCache: z.boolean().optional(), - version: zSemVer.optional(), -}); - -export const zInvocationNodeDataV2 = z.preprocess( - (arg) => { - try { - const data = zInvocationNodeData.parse(arg); - if (!has(data, 'useCache')) { - const nodeTemplates = $store.get()?.getState().nodes.nodeTemplates as - | Record - | undefined; - - const template = nodeTemplates?.[data.type]; - - let useCache = true; - if (template) { - useCache = template.useCache; - } - - Object.assign(data, { useCache }); - } - return data; - } catch { - return arg; - } - }, - zInvocationNodeData.extend({ - useCache: z.boolean(), - }) -); - -// Massage this to get better type safety while developing -export type InvocationNodeData = Omit< - z.infer, - 'type' -> & { - type: AnyInvocationType; -}; - -export const zNotesNodeData = z.object({ - id: z.string().trim().min(1), - type: z.literal('notes'), - label: z.string(), - isOpen: z.boolean(), - notes: z.string(), -}); - -export type NotesNodeData = z.infer; - -const zPosition = z - .object({ - x: z.number(), - y: z.number(), - }) - .default({ x: 0, y: 0 }); - -const zDimension = z.number().gt(0).nullish(); - -export const zWorkflowInvocationNode = z.object({ - id: z.string().trim().min(1), - type: z.literal('invocation'), - data: zInvocationNodeDataV2, - width: zDimension, - height: zDimension, - position: zPosition, -}); - -export type WorkflowInvocationNode = z.infer; - -export const isWorkflowInvocationNode = ( - val: unknown -): val is WorkflowInvocationNode => - zWorkflowInvocationNode.safeParse(val).success; - -export const zWorkflowNotesNode = z.object({ - id: z.string().trim().min(1), - type: z.literal('notes'), - data: zNotesNodeData, - width: zDimension, - height: zDimension, - position: zPosition, -}); - -export const zWorkflowNode = z.discriminatedUnion('type', [ - zWorkflowInvocationNode, - zWorkflowNotesNode, -]); - -export type WorkflowNode = z.infer; - -export const zDefaultWorkflowEdge = z.object({ - source: z.string().trim().min(1), - sourceHandle: z.string().trim().min(1), - target: z.string().trim().min(1), - targetHandle: z.string().trim().min(1), - id: z.string().trim().min(1), - type: z.literal('default'), -}); -export const zCollapsedWorkflowEdge = z.object({ - source: z.string().trim().min(1), - target: z.string().trim().min(1), - id: z.string().trim().min(1), - type: z.literal('collapsed'), -}); - -export const zWorkflowEdge = z.union([ - zDefaultWorkflowEdge, - zCollapsedWorkflowEdge, -]); - -export const zFieldIdentifier = z.object({ - nodeId: z.string().trim().min(1), - fieldName: z.string().trim().min(1), -}); - -export type FieldIdentifier = z.infer; - -export type WorkflowWarning = { - message: string; - issues: string[]; - data: JsonObject; -}; - -const CURRENT_WORKFLOW_VERSION = '1.0.0'; - -export const zWorkflow = z.object({ - name: z.string().default(''), - author: z.string().default(''), - description: z.string().default(''), - version: z.string().default(''), - contact: z.string().default(''), - tags: z.string().default(''), - notes: z.string().default(''), - nodes: z.array(zWorkflowNode).default([]), - edges: z.array(zWorkflowEdge).default([]), - exposedFields: z.array(zFieldIdentifier).default([]), - meta: z - .object({ - version: zSemVer, - }) - .default({ version: CURRENT_WORKFLOW_VERSION }), -}); - -export const zValidatedWorkflow = zWorkflow.transform((workflow) => { - const { nodes, edges } = workflow; - const warnings: WorkflowWarning[] = []; - const invocationNodes = nodes.filter(isWorkflowInvocationNode); - const keyedNodes = keyBy(invocationNodes, 'id'); - edges.forEach((edge, i) => { - const sourceNode = keyedNodes[edge.source]; - const targetNode = keyedNodes[edge.target]; - const issues: string[] = []; - if (!sourceNode) { - issues.push( - `${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t( - 'nodes.doesNotExist' - )}` - ); - } else if ( - edge.type === 'default' && - !(edge.sourceHandle in sourceNode.data.outputs) - ) { - issues.push( - `${i18n.t('nodes.outputField')}"${edge.source}.${ - edge.sourceHandle - }" ${i18n.t('nodes.doesNotExist')}` - ); - } - if (!targetNode) { - issues.push( - `${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t( - 'nodes.doesNotExist' - )}` - ); - } else if ( - edge.type === 'default' && - !(edge.targetHandle in targetNode.data.inputs) - ) { - issues.push( - `${i18n.t('nodes.inputField')} "${edge.target}.${ - edge.targetHandle - }" ${i18n.t('nodes.doesNotExist')}` - ); - } - if (issues.length) { - delete edges[i]; - const src = edge.type === 'default' ? edge.sourceHandle : edge.source; - const tgt = edge.type === 'default' ? edge.targetHandle : edge.target; - warnings.push({ - message: `${i18n.t('nodes.edge')} "${src} -> ${tgt}" ${i18n.t( - 'nodes.skipped' - )}`, - issues, - data: edge, - }); - } - }); - return { workflow, warnings }; -}); - -export type Workflow = z.infer; - -export type ImageMetadataAndWorkflow = { - metadata?: CoreMetadata; - workflow?: Workflow; -}; - -export type CurrentImageNodeData = { - id: string; - type: 'current_image'; - isOpen: boolean; - label: string; -}; - -export type NodeData = - | InvocationNodeData - | NotesNodeData - | CurrentImageNodeData; - -export const isInvocationNode = ( - node?: Node -): node is Node => - Boolean(node && node.type === 'invocation'); - -export const isInvocationNodeData = ( - node?: NodeData -): node is InvocationNodeData => - Boolean(node && !['notes', 'current_image'].includes(node.type)); - -export const isNotesNode = ( - node?: Node -): node is Node => Boolean(node && node.type === 'notes'); - -export const isProgressImageNode = ( - node?: Node -): node is Node => - Boolean(node && node.type === 'current_image'); - -export enum NodeStatus { - PENDING, - IN_PROGRESS, - COMPLETED, - FAILED, -} - -export type NodeExecutionState = { - nodeId: string; - status: NodeStatus; - progress: number | null; - progressImage: ProgressImage | null; - error: string | null; - outputs: AnyResult[]; -}; - -export type FieldComponentProps< - V extends InputFieldValue, - T extends InputFieldTemplate, -> = { - nodeId: string; - field: V; - fieldTemplate: T; -}; diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts new file mode 100644 index 0000000000..7af8a2dd72 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -0,0 +1,91 @@ +import { z } from 'zod'; +import { zFieldIdentifier } from './field'; +import { zInvocationNodeData, zNotesNodeData } from './invocation'; + +// #region Workflow misc +export const zXYPosition = z + .object({ + x: z.number(), + y: z.number(), + }) + .default({ x: 0, y: 0 }); +export type XYPosition = z.infer; + +export const zDimension = z.number().gt(0).nullish(); +export type Dimension = z.infer; +// #endregion + +// #region Workflow Nodes +export const zWorkflowInvocationNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('invocation'), + data: zInvocationNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNotesNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('notes'), + data: zNotesNodeData, + width: zDimension, + height: zDimension, + position: zXYPosition, +}); +export const zWorkflowNode = z.union([ + zWorkflowInvocationNode, + zWorkflowNotesNode, +]); + +export type WorkflowInvocationNode = z.infer; +export type WorkflowNotesNode = z.infer; +export type WorkflowNode = z.infer; + +export const isWorkflowInvocationNode = ( + val: unknown +): val is WorkflowInvocationNode => + zWorkflowInvocationNode.safeParse(val).success; +// #endregion + +// #region Workflow Edges +export const zWorkflowEdgeBase = z.object({ + id: z.string().trim().min(1), + source: z.string().trim().min(1), + target: z.string().trim().min(1), +}); +export const zWorkflowEdgeDefault = zWorkflowEdgeBase.extend({ + type: z.literal('default'), + sourceHandle: z.string().trim().min(1), + targetHandle: z.string().trim().min(1), +}); +export const zWorkflowEdgeCollapsed = zWorkflowEdgeBase.extend({ + type: z.literal('collapsed'), +}); +export const zWorkflowEdge = z.union([ + zWorkflowEdgeDefault, + zWorkflowEdgeCollapsed, +]); + +export type WorkflowEdgeDefault = z.infer; +export type WorkflowEdgeCollapsed = z.infer; +export type WorkflowEdge = z.infer; +// #endregion + +// #region Workflow +export const zWorkflowV2 = z.object({ + name: z.string(), + author: z.string(), + description: z.string(), + version: z.string(), + contact: z.string(), + tags: z.string(), + notes: z.string(), + nodes: z.array(zWorkflowNode), + edges: z.array(zWorkflowEdge), + exposedFields: z.array(zFieldIdentifier), + meta: z.object({ + version: z.literal('2.0.0'), + }), +}); +export type WorkflowV2 = z.infer; +// #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts new file mode 100644 index 0000000000..200bd98e86 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts @@ -0,0 +1,42 @@ +import { get } from 'lodash-es'; +import { FieldInputInstance, FieldInputTemplate } from '../types/field'; + +const FIELD_VALUE_FALLBACK_MAP = { + EnumField: '', + BoardField: undefined, + BooleanField: false, + ClipField: undefined, + ColorField: { r: 0, g: 0, b: 0, a: 1 }, + FloatField: 0, + ImageField: undefined, + IntegerField: 0, + IPAdapterModelField: undefined, + LoRAModelField: undefined, + MainModelField: undefined, + ONNXModelField: undefined, + SchedulerField: 'euler', + SDXLMainModelField: undefined, + SDXLRefinerModelField: undefined, + StringField: '', + T2IAdapterModelField: undefined, + T2IAdapterPolymorphic: undefined, + VAEModelField: undefined, + ControlNetModelField: undefined, +}; + +export const buildFieldInputInstance = ( + id: string, + template: FieldInputTemplate +): FieldInputInstance => { + const fieldInstance: FieldInputInstance = { + id, + name: template.name, + type: template.type, + label: '', + fieldKind: 'input' as const, + value: + template.default ?? get(FIELD_VALUE_FALLBACK_MAP, template.type.name), + }; + + return fieldInstance; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts new file mode 100644 index 0000000000..8d11ac25b9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts @@ -0,0 +1,376 @@ +import { isNumber, startCase } from 'lodash-es'; +import { + BoardFieldInputTemplate, + BooleanFieldInputTemplate, + ColorFieldInputTemplate, + ControlNetModelFieldInputTemplate, + EnumFieldInputTemplate, + FieldInputTemplate, + FieldType, + FloatFieldInputTemplate, + IPAdapterModelFieldInputTemplate, + ImageFieldInputTemplate, + IntegerFieldInputTemplate, + LoRAModelFieldInputTemplate, + MainModelFieldInputTemplate, + SDXLMainModelFieldInputTemplate, + SDXLRefinerModelFieldInputTemplate, + SchedulerFieldInputTemplate, + StatefulFieldType, + StatelessFieldInputTemplate, + StringFieldInputTemplate, + T2IAdapterModelFieldInputTemplate, + VAEModelFieldInputTemplate, + isStatefulFieldType, +} from '../types/field'; +import { InvocationFieldSchema } from '../types/openapi'; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +type FieldInputTemplateBuilder = // valid `any`! + (arg: { + schemaObject: InvocationFieldSchema; + baseField: Omit; + isCollection: boolean; + isPolymorphic: boolean; + }) => T; + +const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder< + IntegerFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: IntegerFieldInputTemplate = { + ...baseField, + type: { name: 'IntegerField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 0, + }; + + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if ( + schemaObject.exclusiveMaximum !== undefined && + isNumber(schemaObject.exclusiveMaximum) + ) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if ( + schemaObject.exclusiveMinimum !== undefined && + isNumber(schemaObject.exclusiveMinimum) + ) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + + return template; +}; + +const buildFloatFieldInputTemplate: FieldInputTemplateBuilder< + FloatFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: FloatFieldInputTemplate = { + ...baseField, + type: { name: 'FloatField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 0, + }; + + if (schemaObject.multipleOf !== undefined) { + template.multipleOf = schemaObject.multipleOf; + } + + if (schemaObject.maximum !== undefined) { + template.maximum = schemaObject.maximum; + } + + if ( + schemaObject.exclusiveMaximum !== undefined && + isNumber(schemaObject.exclusiveMaximum) + ) { + template.exclusiveMaximum = schemaObject.exclusiveMaximum; + } + + if (schemaObject.minimum !== undefined) { + template.minimum = schemaObject.minimum; + } + + if ( + schemaObject.exclusiveMinimum !== undefined && + isNumber(schemaObject.exclusiveMinimum) + ) { + template.exclusiveMinimum = schemaObject.exclusiveMinimum; + } + + return template; +}; + +const buildStringFieldInputTemplate: FieldInputTemplateBuilder< + StringFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: StringFieldInputTemplate = { + ...baseField, + type: { name: 'StringField', isCollection, isPolymorphic }, + default: schemaObject.default ?? '', + }; + + if (schemaObject.minLength !== undefined) { + template.minLength = schemaObject.minLength; + } + + if (schemaObject.maxLength !== undefined) { + template.maxLength = schemaObject.maxLength; + } + + return template; +}; + +const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder< + BooleanFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: BooleanFieldInputTemplate = { + ...baseField, + type: { name: 'BooleanField', isCollection, isPolymorphic }, + default: schemaObject.default ?? false, + }; + + return template; +}; + +const buildMainModelFieldInputTemplate: FieldInputTemplateBuilder< + MainModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: MainModelFieldInputTemplate = { + ...baseField, + type: { name: 'MainModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildSDXLMainModelFieldInputTemplate: FieldInputTemplateBuilder< + SDXLMainModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SDXLMainModelFieldInputTemplate = { + ...baseField, + type: { name: 'SDXLMainModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildRefinerModelFieldInputTemplate: FieldInputTemplateBuilder< + SDXLRefinerModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SDXLRefinerModelFieldInputTemplate = { + ...baseField, + type: { name: 'SDXLRefinerModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildVAEModelFieldInputTemplate: FieldInputTemplateBuilder< + VAEModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: VAEModelFieldInputTemplate = { + ...baseField, + type: { name: 'VAEModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder< + LoRAModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: LoRAModelFieldInputTemplate = { + ...baseField, + type: { name: 'LoRAModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildControlNetModelFieldInputTemplate: FieldInputTemplateBuilder< + ControlNetModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ControlNetModelFieldInputTemplate = { + ...baseField, + type: { name: 'ControlNetModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildIPAdapterModelFieldInputTemplate: FieldInputTemplateBuilder< + IPAdapterModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: IPAdapterModelFieldInputTemplate = { + ...baseField, + type: { name: 'IPAdapterModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildT2IAdapterModelFieldInputTemplate: FieldInputTemplateBuilder< + T2IAdapterModelFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: T2IAdapterModelFieldInputTemplate = { + ...baseField, + type: { name: 'T2IAdapterModelField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildBoardFieldInputTemplate: FieldInputTemplateBuilder< + BoardFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: BoardFieldInputTemplate = { + ...baseField, + type: { name: 'BoardField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildImageFieldInputTemplate: FieldInputTemplateBuilder< + ImageFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ImageFieldInputTemplate = { + ...baseField, + type: { name: 'ImageField', isCollection, isPolymorphic }, + default: schemaObject.default ?? undefined, + }; + + return template; +}; + +const buildEnumFieldInputTemplate: FieldInputTemplateBuilder< + EnumFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const options = schemaObject.enum ?? []; + const template: EnumFieldInputTemplate = { + ...baseField, + type: { name: 'EnumField', isCollection, isPolymorphic }, + options, + ui_choice_labels: schemaObject.ui_choice_labels, + default: schemaObject.default ?? options[0], + }; + + return template; +}; + +const buildColorFieldInputTemplate: FieldInputTemplateBuilder< + ColorFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: ColorFieldInputTemplate = { + ...baseField, + type: { name: 'ColorField', isCollection, isPolymorphic }, + default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, + }; + + return template; +}; + +const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder< + SchedulerFieldInputTemplate +> = ({ schemaObject, baseField, isCollection, isPolymorphic }) => { + const template: SchedulerFieldInputTemplate = { + ...baseField, + type: { name: 'SchedulerField', isCollection, isPolymorphic }, + default: schemaObject.default ?? 'euler', + }; + + return template; +}; + +export const TEMPLATE_BUILDER_MAP: Record< + StatefulFieldType['name'], + FieldInputTemplateBuilder +> = { + BoardField: buildBoardFieldInputTemplate, + BooleanField: buildBooleanFieldInputTemplate, + ColorField: buildColorFieldInputTemplate, + ControlNetModelField: buildControlNetModelFieldInputTemplate, + EnumField: buildEnumFieldInputTemplate, + FloatField: buildFloatFieldInputTemplate, + ImageField: buildImageFieldInputTemplate, + IntegerField: buildIntegerFieldInputTemplate, + IPAdapterModelField: buildIPAdapterModelFieldInputTemplate, + LoRAModelField: buildLoRAModelFieldInputTemplate, + MainModelField: buildMainModelFieldInputTemplate, + SchedulerField: buildSchedulerFieldInputTemplate, + SDXLMainModelField: buildSDXLMainModelFieldInputTemplate, + SDXLRefinerModelField: buildRefinerModelFieldInputTemplate, + StringField: buildStringFieldInputTemplate, + T2IAdapterModelField: buildT2IAdapterModelFieldInputTemplate, + VAEModelField: buildVAEModelFieldInputTemplate, +}; + +export const buildFieldInputTemplate = ( + fieldSchema: InvocationFieldSchema, + name: string, + fieldType: FieldType +): FieldInputTemplate => { + const { + input, + ui_hidden, + ui_component, + ui_type, + ui_order, + ui_choice_labels, + orig_required: required, + } = fieldSchema; + + // This is the base field template that is common to all fields. The builder function will add all other + // properties to this template. + const baseField: Omit = { + name, + title: fieldSchema.title ?? (name ? startCase(name) : ''), + required, + description: fieldSchema.description ?? '', + fieldKind: 'input' as const, + input, + ui_hidden, + ui_component, + ui_type, + ui_order, + ui_choice_labels, + }; + + if (isStatefulFieldType(fieldType)) { + const builder = TEMPLATE_BUILDER_MAP[fieldType.name]; + return builder({ + schemaObject: fieldSchema, + baseField, + isCollection: fieldType.isCollection, + isPolymorphic: fieldType.isPolymorphic, + }); + } + + // This is a StatelessField, create it directly. + const template: StatelessFieldInputTemplate = { + ...baseField, + input: 'connection', // stateless --> connection only inputs + type: fieldType, + default: undefined, // stateless --> no default value + }; + return template; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts index 43ee75b735..7e49be4068 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts @@ -1,13 +1,13 @@ import { logger } from 'app/logging/logger'; import { NodesState } from '../store/types'; -import { Workflow, zWorkflowEdge, zWorkflowNode } from '../types/types'; +import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../types/workflow'; import { fromZodError } from 'zod-validation-error'; import { parseify } from 'common/util/serialize'; import i18n from 'i18next'; -export const buildWorkflow = (nodesState: NodesState): Workflow => { +export const buildWorkflow = (nodesState: NodesState): WorkflowV2 => { const { workflow: workflowMeta, nodes, edges } = nodesState; - const workflow: Workflow = { + const workflow: WorkflowV2 = { ...workflowMeta, nodes: [], edges: [], diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts deleted file mode 100644 index 92e44e9ab2..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/fieldTemplateBuilders.ts +++ /dev/null @@ -1,1210 +0,0 @@ -import { - isArray, - isBoolean, - isInteger, - isNumber, - isString, - startCase, -} from 'lodash-es'; -import { OpenAPIV3_1 } from 'openapi-types'; -import { ControlField } from 'services/api/types'; -import { - COLLECTION_MAP, - POLYMORPHIC_TYPES, - SINGLE_TO_POLYMORPHIC_MAP, - isCollectionItemType, - isPolymorphicItemType, -} from '../types/constants'; -import { - AnyInputFieldTemplate, - BoardInputFieldTemplate, - BooleanCollectionInputFieldTemplate, - BooleanInputFieldTemplate, - BooleanPolymorphicInputFieldTemplate, - ClipInputFieldTemplate, - CollectionInputFieldTemplate, - CollectionItemInputFieldTemplate, - ColorCollectionInputFieldTemplate, - ColorInputFieldTemplate, - ColorPolymorphicInputFieldTemplate, - ConditioningCollectionInputFieldTemplate, - ConditioningField, - ConditioningInputFieldTemplate, - ConditioningPolymorphicInputFieldTemplate, - ControlCollectionInputFieldTemplate, - ControlInputFieldTemplate, - ControlNetModelInputFieldTemplate, - ControlPolymorphicInputFieldTemplate, - DenoiseMaskInputFieldTemplate, - EnumInputFieldTemplate, - FieldType, - FloatCollectionInputFieldTemplate, - FloatInputFieldTemplate, - FloatPolymorphicInputFieldTemplate, - IPAdapterCollectionInputFieldTemplate, - IPAdapterField, - IPAdapterInputFieldTemplate, - IPAdapterModelInputFieldTemplate, - IPAdapterPolymorphicInputFieldTemplate, - ImageCollectionInputFieldTemplate, - ImageField, - ImageInputFieldTemplate, - ImagePolymorphicInputFieldTemplate, - InputFieldTemplate, - InputFieldTemplateBase, - IntegerCollectionInputFieldTemplate, - IntegerInputFieldTemplate, - IntegerPolymorphicInputFieldTemplate, - InvocationFieldSchema, - InvocationSchemaObject, - LatentsCollectionInputFieldTemplate, - LatentsField, - LatentsInputFieldTemplate, - LatentsPolymorphicInputFieldTemplate, - LoRAModelInputFieldTemplate, - MainModelInputFieldTemplate, - MetadataCollectionInputFieldTemplate, - MetadataInputFieldTemplate, - MetadataItemCollectionInputFieldTemplate, - MetadataItemInputFieldTemplate, - MetadataItemPolymorphicInputFieldTemplate, - OpenAPIV3_1SchemaOrRef, - SDXLMainModelInputFieldTemplate, - SDXLRefinerModelInputFieldTemplate, - SchedulerInputFieldTemplate, - StringCollectionInputFieldTemplate, - StringInputFieldTemplate, - StringPolymorphicInputFieldTemplate, - T2IAdapterCollectionInputFieldTemplate, - T2IAdapterField, - T2IAdapterInputFieldTemplate, - T2IAdapterModelInputFieldTemplate, - T2IAdapterPolymorphicInputFieldTemplate, - UNetInputFieldTemplate, - VaeInputFieldTemplate, - VaeModelInputFieldTemplate, - isArraySchemaObject, - isNonArraySchemaObject, - isRefObject, - isSchemaObject, -} from '../types/types'; - -export type BaseFieldProperties = 'name' | 'title' | 'description'; - -export type BuildInputFieldArg = { - schemaObject: InvocationFieldSchema; - baseField: Omit; -}; - -/** - * Transforms an invocation output ref object to field type. - * @param ref The ref string to transform - * @returns The field type. - * - * @example - * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' - */ -export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) => - refObject.$ref.split('/').slice(-1)[0]; - -const buildIntegerInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerInputFieldTemplate => { - const template: IntegerInputFieldTemplate = { - ...baseField, - type: 'integer', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildIntegerPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerPolymorphicInputFieldTemplate => { - const template: IntegerPolymorphicInputFieldTemplate = { - ...baseField, - type: 'IntegerPolymorphic', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildIntegerCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IntegerCollectionInputFieldTemplate => { - const item_default = - isNumber(schemaObject.item_default) && isInteger(schemaObject.item_default) - ? schemaObject.item_default - : 0; - const template: IntegerCollectionInputFieldTemplate = { - ...baseField, - type: 'IntegerCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildFloatInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatInputFieldTemplate => { - const template: FloatInputFieldTemplate = { - ...baseField, - type: 'float', - default: schemaObject.default ?? 0, - }; - - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - - return template; -}; - -const buildFloatPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatPolymorphicInputFieldTemplate => { - const template: FloatPolymorphicInputFieldTemplate = { - ...baseField, - type: 'FloatPolymorphic', - default: schemaObject.default ?? 0, - }; - if (schemaObject.multipleOf !== undefined) { - template.multipleOf = schemaObject.multipleOf; - } - - if (schemaObject.maximum !== undefined) { - template.maximum = schemaObject.maximum; - } - - if ( - schemaObject.exclusiveMaximum !== undefined && - isNumber(schemaObject.exclusiveMaximum) - ) { - template.exclusiveMaximum = schemaObject.exclusiveMaximum; - } - - if (schemaObject.minimum !== undefined) { - template.minimum = schemaObject.minimum; - } - - if ( - schemaObject.exclusiveMinimum !== undefined && - isNumber(schemaObject.exclusiveMinimum) - ) { - template.exclusiveMinimum = schemaObject.exclusiveMinimum; - } - return template; -}; - -const buildFloatCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): FloatCollectionInputFieldTemplate => { - const item_default = isNumber(schemaObject.item_default) - ? schemaObject.item_default - : 0; - const template: FloatCollectionInputFieldTemplate = { - ...baseField, - type: 'FloatCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildStringInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringInputFieldTemplate => { - const template: StringInputFieldTemplate = { - ...baseField, - type: 'string', - default: schemaObject.default ?? '', - }; - - if (schemaObject.minLength !== undefined) { - template.minLength = schemaObject.minLength; - } - - if (schemaObject.maxLength !== undefined) { - template.maxLength = schemaObject.maxLength; - } - - if (schemaObject.pattern !== undefined) { - template.pattern = schemaObject.pattern; - } - - return template; -}; - -const buildStringPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringPolymorphicInputFieldTemplate => { - const template: StringPolymorphicInputFieldTemplate = { - ...baseField, - type: 'StringPolymorphic', - default: schemaObject.default ?? '', - }; - - if (schemaObject.minLength !== undefined) { - template.minLength = schemaObject.minLength; - } - - if (schemaObject.maxLength !== undefined) { - template.maxLength = schemaObject.maxLength; - } - - if (schemaObject.pattern !== undefined) { - template.pattern = schemaObject.pattern; - } - - return template; -}; - -const buildStringCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): StringCollectionInputFieldTemplate => { - const item_default = isString(schemaObject.item_default) - ? schemaObject.item_default - : ''; - const template: StringCollectionInputFieldTemplate = { - ...baseField, - type: 'StringCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildBooleanInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanInputFieldTemplate => { - const template: BooleanInputFieldTemplate = { - ...baseField, - type: 'boolean', - default: schemaObject.default ?? false, - }; - - return template; -}; - -const buildBooleanPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanPolymorphicInputFieldTemplate => { - const template: BooleanPolymorphicInputFieldTemplate = { - ...baseField, - type: 'BooleanPolymorphic', - default: schemaObject.default ?? false, - }; - - return template; -}; - -const buildBooleanCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BooleanCollectionInputFieldTemplate => { - const item_default = - schemaObject.item_default && isBoolean(schemaObject.item_default) - ? schemaObject.item_default - : false; - const template: BooleanCollectionInputFieldTemplate = { - ...baseField, - type: 'BooleanCollection', - default: schemaObject.default ?? [], - item_default, - }; - - return template; -}; - -const buildMainModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): MainModelInputFieldTemplate => { - const template: MainModelInputFieldTemplate = { - ...baseField, - type: 'MainModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildSDXLMainModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SDXLMainModelInputFieldTemplate => { - const template: SDXLMainModelInputFieldTemplate = { - ...baseField, - type: 'SDXLMainModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildRefinerModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SDXLRefinerModelInputFieldTemplate => { - const template: SDXLRefinerModelInputFieldTemplate = { - ...baseField, - type: 'SDXLRefinerModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildVaeModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): VaeModelInputFieldTemplate => { - const template: VaeModelInputFieldTemplate = { - ...baseField, - type: 'VaeModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLoRAModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LoRAModelInputFieldTemplate => { - const template: LoRAModelInputFieldTemplate = { - ...baseField, - type: 'LoRAModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlNetModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlNetModelInputFieldTemplate => { - const template: ControlNetModelInputFieldTemplate = { - ...baseField, - type: 'ControlNetModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => { - const template: IPAdapterModelInputFieldTemplate = { - ...baseField, - type: 'IPAdapterModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterModelInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => { - const template: T2IAdapterModelInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterModelField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildBoardInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): BoardInputFieldTemplate => { - const template: BoardInputFieldTemplate = { - ...baseField, - type: 'BoardField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImageInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImageInputFieldTemplate => { - const template: ImageInputFieldTemplate = { - ...baseField, - type: 'ImageField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImagePolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImagePolymorphicInputFieldTemplate => { - const template: ImagePolymorphicInputFieldTemplate = { - ...baseField, - type: 'ImagePolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildImageCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ImageCollectionInputFieldTemplate => { - const template: ImageCollectionInputFieldTemplate = { - ...baseField, - type: 'ImageCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ImageField) ?? undefined, - }; - - return template; -}; - -const buildDenoiseMaskInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): DenoiseMaskInputFieldTemplate => { - const template: DenoiseMaskInputFieldTemplate = { - ...baseField, - type: 'DenoiseMaskField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsInputFieldTemplate => { - const template: LatentsInputFieldTemplate = { - ...baseField, - type: 'LatentsField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsPolymorphicInputFieldTemplate => { - const template: LatentsPolymorphicInputFieldTemplate = { - ...baseField, - type: 'LatentsPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildLatentsCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): LatentsCollectionInputFieldTemplate => { - const template: LatentsCollectionInputFieldTemplate = { - ...baseField, - type: 'LatentsCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as LatentsField) ?? undefined, - }; - - return template; -}; - -const buildConditioningInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningInputFieldTemplate => { - const template: ConditioningInputFieldTemplate = { - ...baseField, - type: 'ConditioningField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildConditioningPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningPolymorphicInputFieldTemplate => { - const template: ConditioningPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ConditioningPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildConditioningCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ConditioningCollectionInputFieldTemplate => { - const template: ConditioningCollectionInputFieldTemplate = { - ...baseField, - type: 'ConditioningCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ConditioningField) ?? undefined, - }; - - return template; -}; - -const buildUNetInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): UNetInputFieldTemplate => { - const template: UNetInputFieldTemplate = { - ...baseField, - type: 'UNetField', - - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildClipInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ClipInputFieldTemplate => { - const template: ClipInputFieldTemplate = { - ...baseField, - type: 'ClipField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildVaeInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): VaeInputFieldTemplate => { - const template: VaeInputFieldTemplate = { - ...baseField, - type: 'VaeField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlInputFieldTemplate => { - const template: ControlInputFieldTemplate = { - ...baseField, - type: 'ControlField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlPolymorphicInputFieldTemplate => { - const template: ControlPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ControlPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildControlCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ControlCollectionInputFieldTemplate => { - const template: ControlCollectionInputFieldTemplate = { - ...baseField, - type: 'ControlCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as ControlField) ?? undefined, - }; - - return template; -}; - -const buildIPAdapterInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterInputFieldTemplate => { - const template: IPAdapterInputFieldTemplate = { - ...baseField, - type: 'IPAdapterField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterPolymorphicInputFieldTemplate => { - const template: IPAdapterPolymorphicInputFieldTemplate = { - ...baseField, - type: 'IPAdapterPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildIPAdapterCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): IPAdapterCollectionInputFieldTemplate => { - const template: IPAdapterCollectionInputFieldTemplate = { - ...baseField, - type: 'IPAdapterCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as IPAdapterField) ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => { - const template: T2IAdapterInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterField', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => { - const template: T2IAdapterPolymorphicInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterPolymorphic', - default: schemaObject.default ?? undefined, - }; - - return template; -}; - -const buildT2IAdapterCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => { - const template: T2IAdapterCollectionInputFieldTemplate = { - ...baseField, - type: 'T2IAdapterCollection', - default: schemaObject.default ?? [], - item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined, - }; - - return template; -}; - -const buildEnumInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): EnumInputFieldTemplate => { - const options = schemaObject.enum ?? []; - const template: EnumInputFieldTemplate = { - ...baseField, - type: 'enum', - options, - ui_choice_labels: schemaObject.ui_choice_labels, - default: schemaObject.default ?? options[0], - }; - - return template; -}; - -const buildCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): CollectionInputFieldTemplate => { - const template: CollectionInputFieldTemplate = { - ...baseField, - type: 'Collection', - default: [], - }; - - return template; -}; - -const buildCollectionItemInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): CollectionItemInputFieldTemplate => { - const template: CollectionItemInputFieldTemplate = { - ...baseField, - type: 'CollectionItem', - default: undefined, - }; - - return template; -}; - -const buildAnyInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): AnyInputFieldTemplate => { - const template: AnyInputFieldTemplate = { - ...baseField, - type: 'Any', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemInputFieldTemplate => { - const template: MetadataItemInputFieldTemplate = { - ...baseField, - type: 'MetadataItemField', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemCollectionInputFieldTemplate => { - const template: MetadataItemCollectionInputFieldTemplate = { - ...baseField, - type: 'MetadataItemCollection', - default: undefined, - }; - - return template; -}; - -const buildMetadataItemPolymorphicInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataItemPolymorphicInputFieldTemplate => { - const template: MetadataItemPolymorphicInputFieldTemplate = { - ...baseField, - type: 'MetadataItemPolymorphic', - default: undefined, - }; - - return template; -}; - -const buildMetadataDictInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataInputFieldTemplate => { - const template: MetadataInputFieldTemplate = { - ...baseField, - type: 'MetadataField', - default: undefined, - }; - - return template; -}; - -const buildMetadataCollectionInputFieldTemplate = ({ - baseField, -}: BuildInputFieldArg): MetadataCollectionInputFieldTemplate => { - const template: MetadataCollectionInputFieldTemplate = { - ...baseField, - type: 'MetadataCollection', - default: undefined, - }; - - return template; -}; - -const buildColorInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorInputFieldTemplate => { - const template: ColorInputFieldTemplate = { - ...baseField, - type: 'ColorField', - default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, - }; - - return template; -}; - -const buildColorPolymorphicInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorPolymorphicInputFieldTemplate => { - const template: ColorPolymorphicInputFieldTemplate = { - ...baseField, - type: 'ColorPolymorphic', - default: schemaObject.default ?? { r: 127, g: 127, b: 127, a: 255 }, - }; - - return template; -}; - -const buildColorCollectionInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): ColorCollectionInputFieldTemplate => { - const template: ColorCollectionInputFieldTemplate = { - ...baseField, - type: 'ColorCollection', - default: schemaObject.default ?? [], - }; - - return template; -}; - -const buildSchedulerInputFieldTemplate = ({ - schemaObject, - baseField, -}: BuildInputFieldArg): SchedulerInputFieldTemplate => { - const template: SchedulerInputFieldTemplate = { - ...baseField, - type: 'Scheduler', - default: schemaObject.default ?? 'euler', - }; - - return template; -}; - -export const getFieldType = ( - schemaObject: OpenAPIV3_1SchemaOrRef -): string | undefined => { - if (isSchemaObject(schemaObject)) { - if (!schemaObject.type) { - // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf - - if (schemaObject.allOf) { - const allOf = schemaObject.allOf; - if (allOf && allOf[0] && isRefObject(allOf[0])) { - return refObjectToSchemaName(allOf[0]); - } - } else if (schemaObject.anyOf) { - // ignore null types - const anyOf = schemaObject.anyOf.filter((i) => { - if (isSchemaObject(i)) { - if (i.type === 'null') { - return false; - } - } - return true; - }); - if (anyOf.length === 1) { - if (isRefObject(anyOf[0])) { - return refObjectToSchemaName(anyOf[0]); - } else if (isSchemaObject(anyOf[0])) { - return getFieldType(anyOf[0]); - } - } - /** - * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is: - * - an `anyOf` with two items - * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` - * - the other is a `SchemaObject` or `ReferenceObject` of type T - * - * Any other cases we ignore. - */ - - let firstType: string | undefined; - let secondType: string | undefined; - - if (isArraySchemaObject(anyOf[0])) { - // first is array, second is not - const first = anyOf[0].items; - const second = anyOf[1]; - if (isRefObject(first) && isRefObject(second)) { - firstType = refObjectToSchemaName(first); - secondType = refObjectToSchemaName(second); - } else if ( - isNonArraySchemaObject(first) && - isNonArraySchemaObject(second) - ) { - firstType = first.type; - secondType = second.type; - } - } else if (isArraySchemaObject(anyOf[1])) { - // first is not array, second is - const first = anyOf[0]; - const second = anyOf[1].items; - if (isRefObject(first) && isRefObject(second)) { - firstType = refObjectToSchemaName(first); - secondType = refObjectToSchemaName(second); - } else if ( - isNonArraySchemaObject(first) && - isNonArraySchemaObject(second) - ) { - firstType = first.type; - secondType = second.type; - } - } - if (firstType === secondType && isPolymorphicItemType(firstType)) { - return SINGLE_TO_POLYMORPHIC_MAP[firstType]; - } - } - } else if (schemaObject.enum) { - return 'enum'; - } else if (schemaObject.type) { - if (schemaObject.type === 'number') { - // floats are "number" in OpenAPI, while ints are "integer" - we need to distinguish them - return 'float'; - } else if (schemaObject.type === 'array') { - const itemType = isSchemaObject(schemaObject.items) - ? schemaObject.items.type - : refObjectToSchemaName(schemaObject.items); - - if (isArray(itemType)) { - // This is a nested array, which we don't support - return; - } - - if (isCollectionItemType(itemType)) { - return COLLECTION_MAP[itemType]; - } - - return; - } else if (!isArray(schemaObject.type)) { - return schemaObject.type; - } - } - } else if (isRefObject(schemaObject)) { - return refObjectToSchemaName(schemaObject); - } - return; -}; - -const TEMPLATE_BUILDER_MAP: { - [key in FieldType]?: (arg: BuildInputFieldArg) => InputFieldTemplate; -} = { - BoardField: buildBoardInputFieldTemplate, - Any: buildAnyInputFieldTemplate, - boolean: buildBooleanInputFieldTemplate, - BooleanCollection: buildBooleanCollectionInputFieldTemplate, - BooleanPolymorphic: buildBooleanPolymorphicInputFieldTemplate, - ClipField: buildClipInputFieldTemplate, - Collection: buildCollectionInputFieldTemplate, - CollectionItem: buildCollectionItemInputFieldTemplate, - ColorCollection: buildColorCollectionInputFieldTemplate, - ColorField: buildColorInputFieldTemplate, - ColorPolymorphic: buildColorPolymorphicInputFieldTemplate, - ConditioningCollection: buildConditioningCollectionInputFieldTemplate, - ConditioningField: buildConditioningInputFieldTemplate, - ConditioningPolymorphic: buildConditioningPolymorphicInputFieldTemplate, - ControlCollection: buildControlCollectionInputFieldTemplate, - ControlField: buildControlInputFieldTemplate, - ControlNetModelField: buildControlNetModelInputFieldTemplate, - ControlPolymorphic: buildControlPolymorphicInputFieldTemplate, - DenoiseMaskField: buildDenoiseMaskInputFieldTemplate, - enum: buildEnumInputFieldTemplate, - float: buildFloatInputFieldTemplate, - FloatCollection: buildFloatCollectionInputFieldTemplate, - FloatPolymorphic: buildFloatPolymorphicInputFieldTemplate, - ImageCollection: buildImageCollectionInputFieldTemplate, - ImageField: buildImageInputFieldTemplate, - ImagePolymorphic: buildImagePolymorphicInputFieldTemplate, - integer: buildIntegerInputFieldTemplate, - IntegerCollection: buildIntegerCollectionInputFieldTemplate, - IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate, - IPAdapterCollection: buildIPAdapterCollectionInputFieldTemplate, - IPAdapterField: buildIPAdapterInputFieldTemplate, - IPAdapterModelField: buildIPAdapterModelInputFieldTemplate, - IPAdapterPolymorphic: buildIPAdapterPolymorphicInputFieldTemplate, - LatentsCollection: buildLatentsCollectionInputFieldTemplate, - LatentsField: buildLatentsInputFieldTemplate, - LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate, - LoRAModelField: buildLoRAModelInputFieldTemplate, - MetadataItemField: buildMetadataItemInputFieldTemplate, - MetadataItemCollection: buildMetadataItemCollectionInputFieldTemplate, - MetadataItemPolymorphic: buildMetadataItemPolymorphicInputFieldTemplate, - MetadataField: buildMetadataDictInputFieldTemplate, - MetadataCollection: buildMetadataCollectionInputFieldTemplate, - MainModelField: buildMainModelInputFieldTemplate, - Scheduler: buildSchedulerInputFieldTemplate, - SDXLMainModelField: buildSDXLMainModelInputFieldTemplate, - SDXLRefinerModelField: buildRefinerModelInputFieldTemplate, - string: buildStringInputFieldTemplate, - StringCollection: buildStringCollectionInputFieldTemplate, - StringPolymorphic: buildStringPolymorphicInputFieldTemplate, - T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate, - T2IAdapterField: buildT2IAdapterInputFieldTemplate, - T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate, - T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate, - UNetField: buildUNetInputFieldTemplate, - VaeField: buildVaeInputFieldTemplate, - VaeModelField: buildVaeModelInputFieldTemplate, -}; - -const isTemplatedFieldType = ( - fieldType: string | undefined -): fieldType is keyof typeof TEMPLATE_BUILDER_MAP => - Boolean(fieldType && fieldType in TEMPLATE_BUILDER_MAP); - -/** - * Builds an input field from an invocation schema property. - * @param fieldSchema The schema object - * @returns An input field - */ -export const buildInputFieldTemplate = ( - nodeSchema: InvocationSchemaObject, - fieldSchema: InvocationFieldSchema, - name: string, - fieldType: FieldType -) => { - const { - input, - ui_hidden, - ui_component, - ui_type, - ui_order, - ui_choice_labels, - item_default, - } = fieldSchema; - - const extra = { - // TODO: Can we support polymorphic inputs in the UI? - input: POLYMORPHIC_TYPES.includes(fieldType) ? 'connection' : input, - ui_hidden, - ui_component, - ui_type, - required: nodeSchema.required?.includes(name) ?? false, - ui_order, - ui_choice_labels, - item_default, - }; - - const baseField = { - name, - title: fieldSchema.title ?? (name ? startCase(name) : ''), - description: fieldSchema.description ?? '', - fieldKind: 'input' as const, - ...extra, - }; - - if (!isTemplatedFieldType(fieldType)) { - return; - } - - const builder = TEMPLATE_BUILDER_MAP[fieldType]; - - if (!builder) { - return; - } - - return builder({ - schemaObject: fieldSchema, - baseField, - }); -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts b/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts deleted file mode 100644 index ca2513649d..0000000000 --- a/invokeai/frontend/web/src/features/nodes/util/fieldValueBuilders.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { FieldType, InputFieldTemplate, InputFieldValue } from '../types/types'; - -const FIELD_VALUE_FALLBACK_MAP: { - [key in FieldType]: InputFieldValue['value']; -} = { - Any: undefined, - enum: '', - BoardField: undefined, - boolean: false, - BooleanCollection: [], - BooleanPolymorphic: false, - ClipField: undefined, - Collection: [], - CollectionItem: undefined, - ColorCollection: [], - ColorField: undefined, - ColorPolymorphic: undefined, - ConditioningCollection: [], - ConditioningField: undefined, - ConditioningPolymorphic: undefined, - ControlCollection: [], - ControlField: undefined, - ControlNetModelField: undefined, - ControlPolymorphic: undefined, - DenoiseMaskField: undefined, - float: 0, - FloatCollection: [], - FloatPolymorphic: 0, - ImageCollection: [], - ImageField: undefined, - ImagePolymorphic: undefined, - integer: 0, - IntegerCollection: [], - IntegerPolymorphic: 0, - IPAdapterCollection: [], - IPAdapterField: undefined, - IPAdapterModelField: undefined, - IPAdapterPolymorphic: undefined, - LatentsCollection: [], - LatentsField: undefined, - LatentsPolymorphic: undefined, - MetadataItemField: undefined, - MetadataItemCollection: [], - MetadataItemPolymorphic: undefined, - MetadataField: undefined, - MetadataCollection: [], - LoRAModelField: undefined, - MainModelField: undefined, - ONNXModelField: undefined, - Scheduler: 'euler', - SDXLMainModelField: undefined, - SDXLRefinerModelField: undefined, - string: '', - StringCollection: [], - StringPolymorphic: '', - T2IAdapterCollection: [], - T2IAdapterField: undefined, - T2IAdapterModelField: undefined, - T2IAdapterPolymorphic: undefined, - UNetField: undefined, - VaeField: undefined, - VaeModelField: undefined, -}; - -export const buildInputFieldValue = ( - id: string, - template: InputFieldTemplate -): InputFieldValue => { - // TODO: this should be `fieldValue: InputFieldValue`, but that introduces a TS issue I couldn't - // resolve - for some reason, it doesn't like `template.type`, which is the discriminant for both - // `InputFieldTemplate` union. It is (type-structurally) equal to the discriminant for the - // `InputFieldValue` union, but TS doesn't seem to like it... - const fieldValue = { - id, - name: template.name, - type: template.type, - label: '', - fieldKind: 'input', - } as InputFieldValue; - - fieldValue.value = - template.default ?? FIELD_VALUE_FALLBACK_MAP[template.type]; - - return fieldValue; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts b/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts index b235fe8a07..2ed5faca29 100644 --- a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts @@ -1,8 +1,8 @@ import { isNil } from 'lodash-es'; -import { InputFieldTemplate, OutputFieldTemplate } from '../types/types'; +import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; export const getSortedFilteredFieldNames = ( - fields: InputFieldTemplate[] | OutputFieldTemplate[] + fields: FieldInputTemplate[] | FieldOutputTemplate[] ): string[] => { const visibleFields = fields.filter((field) => !field.ui_hidden); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts index 60d4e36dca..ff6028b38e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts @@ -6,8 +6,8 @@ import { ControlField, ControlNetInvocation, CoreMetadataInvocation, + NonNullableGraph, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, CONTROL_NET_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts index 9825ce754e..edbba61b73 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts @@ -1,13 +1,13 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; import { roundToMultiple } from 'common/util/roundDownToMultiple'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, ESRGANInvocation, Edge, LatentsToImageInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { DENOISE_LATENTS, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts index 93c6cdb284..9dd8b25368 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts @@ -6,8 +6,8 @@ import { CoreMetadataInvocation, IPAdapterInvocation, IPAdapterMetadataField, + NonNullableGraph, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, IP_ADAPTER_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts index 926fa3a8f3..1676a5f53d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts @@ -1,7 +1,6 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; -import { LinearUIOutputInvocation } from 'services/api/types'; +import { LinearUIOutputInvocation, NonNullableGraph } from 'services/api/types'; import { CANVAS_OUTPUT, LATENTS_TO_IMAGE, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts index 66c2bd0444..acbe53e611 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts @@ -1,9 +1,9 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { forEach, size } from 'lodash-es'; import { CoreMetadataInvocation, LoraLoaderInvocation, + NonNullableGraph, } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts index 94fddccc8f..d4cd5e83cb 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts @@ -1,8 +1,8 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageNSFWBlurInvocation, LatentsToImageInvocation, + NonNullableGraph, } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts index 04841f0def..544958c39d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts @@ -1,11 +1,10 @@ import { RootState } from 'app/store/store'; import { LoRAMetadataItem, - NonNullableGraph, zLoRAMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { forEach, size } from 'lodash-es'; -import { SDXLLoraLoaderInvocation } from 'services/api/types'; +import { NonNullableGraph, SDXLLoraLoaderInvocation } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, LORA_LOADER, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts index 136263f63e..8976d7ed5f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts @@ -2,9 +2,9 @@ import { RootState } from 'app/store/store'; import { CreateDenoiseMaskInvocation, ImageDTO, + NonNullableGraph, SeamlessModeInvocation, } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; import { CANVAS_OUTPUT, INPAINT_IMAGE_RESIZE_UP, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts index ba341a8a3d..d062b25309 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts @@ -1,7 +1,5 @@ import { RootState } from 'app/store/store'; -import { SeamlessModeInvocation } from 'services/api/types'; -import { NonNullableGraph } from '../../types/types'; -import { upsertMetadata } from './metadata'; +import { NonNullableGraph, SeamlessModeInvocation } from 'services/api/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, CANVAS_INPAINT_GRAPH, @@ -16,6 +14,7 @@ import { SDXL_TEXT_TO_IMAGE_GRAPH, SEAMLESS, } from './constants'; +import { upsertMetadata } from './metadata'; export const addSeamlessToLinearGraph = ( state: RootState, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts index 71c2aaeede..550f9ba5f3 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts @@ -4,9 +4,10 @@ import { omit } from 'lodash-es'; import { CollectInvocation, CoreMetadataInvocation, + NonNullableGraph, + T2IAdapterField, T2IAdapterInvocation, } from 'services/api/types'; -import { NonNullableGraph, T2IAdapterField } from '../../types/types'; import { CANVAS_COHERENCE_DENOISE_LATENTS, T2I_ADAPTER_COLLECT, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts index f049a89e36..438bbfd892 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts @@ -1,5 +1,5 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; +import { NonNullableGraph } from 'services/api/types'; import { CANVAS_COHERENCE_INPAINT_CREATE_MASK, CANVAS_IMAGE_TO_IMAGE_GRAPH, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts index c43437e4fc..f553e6d0f9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts @@ -1,10 +1,10 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { ImageNSFWBlurInvocation, ImageWatermarkInvocation, LatentsToImageInvocation, + NonNullableGraph, } from 'services/api/types'; import { LATENTS_TO_IMAGE, NSFW_CHECKER, WATERMARKER } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts index 8331c81eb3..60143252ab 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts @@ -1,10 +1,10 @@ import { BoardId } from 'features/gallery/store/types'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ESRGANModelName } from 'features/parameters/store/postprocessingSlice'; import { ESRGANInvocation, Graph, LinearUIOutputInvocation, + NonNullableGraph, } from 'services/api/types'; import { ESRGAN, LINEAR_UI_OUTPUT } from './constants'; import { addCoreMetadataNode, upsertMetadata } from './metadata'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts index d268a3990d..66500c9ce5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts @@ -1,6 +1,5 @@ import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO } from 'services/api/types'; +import { ImageDTO, NonNullableGraph } from 'services/api/types'; import { buildCanvasImageToImageGraph } from './buildCanvasImageToImageGraph'; import { buildCanvasInpaintGraph } from './buildCanvasInpaintGraph'; import { buildCanvasOutpaintGraph } from './buildCanvasOutpaintGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts index a86fdb4ce6..2866aeff07 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts @@ -1,12 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types'; +import { + ImageDTO, + ImageToLatentsInvocation, + NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts index 48052e2a94..6253ce1f18 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -8,12 +7,13 @@ import { ImageToLatentsInvocation, MaskEdgeInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts index 31cf5ca7e8..11573c21c2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts @@ -1,18 +1,18 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageToLatentsInvocation, InfillPatchMatchInvocation, InfillTileInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts index 8281c9c248..f579a7d9e7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts @@ -1,14 +1,18 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; -import { ImageDTO, ImageToLatentsInvocation } from 'services/api/types'; +import { + ImageDTO, + ImageToLatentsInvocation, + NonNullableGraph, +} from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; +import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; import { @@ -26,7 +30,6 @@ import { SEAMLESS, } from './constants'; import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt'; -import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addCoreMetadataNode } from './metadata'; /** diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts index 40626e289a..de1dd0dfd2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts @@ -1,6 +1,5 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { CreateDenoiseMaskInvocation, ImageBlurInvocation, @@ -8,13 +7,14 @@ import { ImageToLatentsInvocation, MaskEdgeInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts index c7302cd56d..2f8b4fd653 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts @@ -1,19 +1,19 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageDTO, ImageToLatentsInvocation, InfillPatchMatchInvocation, InfillTileInvocation, NoiseInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts index 2a712f2ef3..0a456aaccf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts index 5c0c91ca71..72d7c1e460 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts @@ -1,15 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts index 59f8d4123f..865a4535ae 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts @@ -1,10 +1,9 @@ import { NUMPY_RAND_MAX } from 'app/constants'; import { RootState } from 'app/store/store'; import { generateSeeds } from 'common/util/generateSeeds'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { range } from 'lodash-es'; import { components } from 'services/api/schema'; -import { Batch, BatchConfig } from 'services/api/types'; +import { Batch, BatchConfig, NonNullableGraph } from 'services/api/types'; import { CANVAS_COHERENCE_NOISE, METADATA, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts index d1897ad9ee..8d4d0ae35f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts @@ -1,15 +1,15 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageResizeInvocation, ImageToLatentsInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts index 0b57dcd5bf..621219eb67 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { ImageResizeInvocation, ImageToLatentsInvocation, + NonNullableGraph, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts index 37e9b293c6..df28fbbd62 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts @@ -1,17 +1,16 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; +import { NonNullableGraph } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; import { addSDXLLoRAsToGraph } from './addSDXLLoRAstoGraph'; import { addSDXLRefinerToGraph } from './addSDXLRefinerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; -import { addCoreMetadataNode } from './metadata'; import { LATENTS_TO_IMAGE, NEGATIVE_CONDITIONING, @@ -24,6 +23,7 @@ import { SEAMLESS, } from './constants'; import { buildSDXLStylePrompts } from './helpers/craftSDXLStylePrompt'; +import { addCoreMetadataNode } from './metadata'; export const buildLinearSDXLTextToImageGraph = ( state: RootState diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts index f097cf0c42..ffec2a409f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts @@ -1,21 +1,20 @@ import { logger } from 'app/logging/logger'; import { RootState } from 'app/store/store'; -import { NonNullableGraph } from 'features/nodes/types/types'; import { DenoiseLatentsInvocation, + NonNullableGraph, ONNXTextToLatentsInvocation, } from 'services/api/types'; import { addControlNetToLinearGraph } from './addControlNetToLinearGraph'; import { addHrfToGraph } from './addHrfToGraph'; import { addIPAdapterToLinearGraph } from './addIPAdapterToLinearGraph'; +import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addLoRAsToGraph } from './addLoRAsToGraph'; import { addNSFWCheckerToGraph } from './addNSFWCheckerToGraph'; -import { addLinearUIOutputNode } from './addLinearUIOutputNode'; import { addSeamlessToLinearGraph } from './addSeamlessToLinearGraph'; import { addT2IAdaptersToLinearGraph } from './addT2IAdapterToLinearGraph'; import { addVAEToGraph } from './addVAEToGraph'; import { addWatermarkerToGraph } from './addWatermarkerToGraph'; -import { addCoreMetadataNode } from './metadata'; import { CLIP_SKIP, DENOISE_LATENTS, @@ -28,6 +27,7 @@ import { SEAMLESS, TEXT_TO_IMAGE_GRAPH, } from './constants'; +import { addCoreMetadataNode } from './metadata'; export const buildLinearTextToImageGraph = ( state: RootState diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts index eb782f456a..9ed0eb1d32 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts @@ -1,16 +1,20 @@ import { NodesState } from 'features/nodes/store/types'; -import { InputFieldValue, isInvocationNode } from 'features/nodes/types/types'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; import { buildWorkflow } from '../buildWorkflow'; +import { + FieldInputInstance, + isColorFieldInputInstance, +} from 'features/nodes/types/field'; /** * We need to do special handling for some fields */ -export const parseFieldValue = (field: InputFieldValue) => { - if (field.type === 'ColorField') { +export const parseFieldValue = (field: FieldInputInstance) => { + if (isColorFieldInputInstance(field)) { if (field.value) { const clonedValue = cloneDeep(field.value); diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts index c80e1c80c6..f78a0af035 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts @@ -1,5 +1,4 @@ -import { NonNullableGraph } from 'features/nodes/types/types'; -import { CoreMetadataInvocation } from 'services/api/types'; +import { CoreMetadataInvocation, NonNullableGraph } from 'services/api/types'; import { JsonObject } from 'type-fest'; import { METADATA } from './constants'; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts new file mode 100644 index 0000000000..133a3d11c9 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts @@ -0,0 +1,233 @@ +import { t } from 'i18next'; +import { isArray } from 'lodash-es'; +import { OpenAPIV3_1 } from 'openapi-types'; +import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; +import { FieldType } from '../types/field'; +import { + OpenAPIV3_1SchemaOrRef, + isArraySchemaObject, + isInvocationFieldSchema, + isNonArraySchemaObject, + isRefObject, + isSchemaObject, +} from '../types/openapi'; + +/** + * Transforms an invocation output ref object to field type. + * @param ref The ref string to transform + * @returns The field type. + * + * @example + * refObjectToFieldType({ "$ref": "#/components/schemas/ImageField" }) --> 'ImageField' + */ +export const refObjectToSchemaName = (refObject: OpenAPIV3_1.ReferenceObject) => + refObject.$ref.split('/').slice(-1)[0]; + +const OPENAPI_TO_FIELD_TYPE_MAP: Record = { + integer: 'IntegerField', + number: 'FloatField', + string: 'StringField', + boolean: 'BooleanField', +}; + +const isCollectionFieldType = (fieldType: string) => { + /** + * CollectionField is `list[Any]` in the pydantic schema, but we need to distinguish between + * it and other `list[Any]` fields, due to its special internal handling. + * + * In pydantic, it gets an explicit field type of `CollectionField`. + */ + if (fieldType === 'CollectionField') { + return true; + } + return false; +}; + +export const parseFieldType = ( + schemaObject: OpenAPIV3_1SchemaOrRef +): FieldType => { + if (isInvocationFieldSchema(schemaObject)) { + // Check if this field has an explicit type provided by the node schema + const { ui_type } = schemaObject; + if (ui_type) { + return { + name: ui_type, + isCollection: isCollectionFieldType(ui_type), + isPolymorphic: false, + }; + } + } + if (isSchemaObject(schemaObject)) { + if (!schemaObject.type) { + // if schemaObject has no type, then it should have one of allOf, anyOf, oneOf + + if (schemaObject.allOf) { + const allOf = schemaObject.allOf; + if (allOf && allOf[0] && isRefObject(allOf[0])) { + // This is a single ref type + const name = refObjectToSchemaName(allOf[0]); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + } else if (schemaObject.anyOf) { + // ignore null types + const filteredAnyOf = schemaObject.anyOf.filter((i) => { + if (isSchemaObject(i)) { + if (i.type === 'null') { + return false; + } + } + return true; + }); + if (filteredAnyOf.length === 1) { + // This is a single ref type + if (isRefObject(filteredAnyOf[0])) { + const name = refObjectToSchemaName(filteredAnyOf[0]); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } else if (isSchemaObject(filteredAnyOf[0])) { + return parseFieldType(filteredAnyOf[0]); + } + } + /** + * Handle Polymorphic inputs, eg string | string[]. In OpenAPI, this is: + * - an `anyOf` with two items + * - one is an `ArraySchemaObject` with a single `SchemaObject or ReferenceObject` of type T in its `items` + * - the other is a `SchemaObject` or `ReferenceObject` of type T + * + * Any other cases we ignore. + */ + + let firstType: string | undefined; + let secondType: string | undefined; + + if (isArraySchemaObject(filteredAnyOf[0])) { + // first is array, second is not + const first = filteredAnyOf[0].items; + const second = filteredAnyOf[1]; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } else if (isArraySchemaObject(filteredAnyOf[1])) { + // first is not array, second is + const first = filteredAnyOf[0]; + const second = filteredAnyOf[1].items; + if (isRefObject(first) && isRefObject(second)) { + firstType = refObjectToSchemaName(first); + secondType = refObjectToSchemaName(second); + } else if ( + isNonArraySchemaObject(first) && + isNonArraySchemaObject(second) + ) { + firstType = first.type; + secondType = second.type; + } + } + if (firstType && firstType === secondType) { + return { + name: OPENAPI_TO_FIELD_TYPE_MAP[firstType] ?? firstType, + isCollection: false, + isPolymorphic: true, // <-- don't forget, polymorphic! + }; + } + } + } else if (schemaObject.enum) { + return { name: 'EnumField', isCollection: false, isPolymorphic: false }; + } else if (schemaObject.type) { + if (schemaObject.type === 'array') { + // We need to get the type of the items + if (isSchemaObject(schemaObject.items)) { + const itemType = schemaObject.items.type; + if (!itemType || isArray(itemType)) { + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: itemType, + }) + ); + } + // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' + const name = OPENAPI_TO_FIELD_TYPE_MAP[itemType]; + if (!name) { + // it's 'null', 'object', or 'array' - skip + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: itemType, + }) + ); + } + return { + name, + isCollection: true, // <-- don't forget, collection! + isPolymorphic: false, + }; + } + + // This is a ref object, extract the type name + const name = refObjectToSchemaName(schemaObject.items); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: true, // <-- don't forget, collection! + isPolymorphic: false, + }; + } else if (!isArray(schemaObject.type)) { + // This is an OpenAPI primitive - 'null', 'object', 'array', 'integer', 'number', 'string', 'boolean' + const name = OPENAPI_TO_FIELD_TYPE_MAP[schemaObject.type]; + if (!name) { + // it's 'null', 'object', or 'array' - skip + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedArrayItemType', { + type: schemaObject.type, + }) + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + } + } else if (isRefObject(schemaObject)) { + const name = refObjectToSchemaName(schemaObject); + if (!name) { + throw new FieldTypeParseError( + t('nodes.unableToExtractSchemaNameFromRef') + ); + } + return { + name, + isCollection: false, + isPolymorphic: false, + }; + } + throw new FieldTypeParseError(t('nodes.unableToParseFieldType')); +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 8737fc52b9..2c59b6cb14 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -2,24 +2,24 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import { reduce, startCase } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; -import { AnyInvocationType } from 'services/events/types'; +import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; +import { InvocationTemplate } from '../types/invocation'; import { - InputFieldTemplate, InvocationSchemaObject, - InvocationTemplate, - OutputFieldTemplate, - isFieldType, isInvocationFieldSchema, isInvocationOutputSchemaObject, isInvocationSchemaObject, -} from '../types/types'; -import { buildInputFieldTemplate, getFieldType } from './fieldTemplateBuilders'; +} from '../types/openapi'; +import { buildFieldInputTemplate } from './buildFieldInputTemplate'; +import { parseFieldType } from './parseFieldType'; +import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; +import { t } from 'i18next'; const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache']; const RESERVED_OUTPUT_FIELD_NAMES = ['type']; const RESERVED_FIELD_TYPES = ['IsIntermediate']; -const invocationDenylist: AnyInvocationType[] = ['graph', 'linear_ui_output']; +const invocationDenylist: string[] = ['graph', 'linear_ui_output']; const isReservedInputField = (nodeType: string, fieldName: string) => { if (RESERVED_INPUT_FIELD_NAMES.includes(fieldName)) { @@ -83,13 +83,13 @@ export const parseSchema = ( const inputs = reduce( schema.properties, ( - inputsAccumulator: Record, + inputsAccumulator: Record, property, propertyName ) => { if (isReservedInputField(type, propertyName)) { logger('nodes').trace( - { node: type, fieldName: propertyName, field: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Skipped reserved input field' ); return inputsAccumulator; @@ -97,79 +97,53 @@ export const parseSchema = ( if (!isInvocationFieldSchema(property)) { logger('nodes').warn( - { node: type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Unhandled input property' ); return inputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); + try { + const fieldType = parseFieldType(property); - if (!fieldType) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - 'Missing input field type' + if (fieldType.name === 'WorkflowField') { + // This supports workflows, set the flag and skip to next field + withWorkflow = true; + return inputsAccumulator; + } + + if (isReservedFieldType(fieldType.name)) { + // Skip processing this reserved field + return inputsAccumulator; + } + + const fieldInputTemplate = buildFieldInputTemplate( + property, + propertyName, + fieldType ); - return inputsAccumulator; + + inputsAccumulator[propertyName] = fieldInputTemplate; + } catch (e) { + if ( + e instanceof FieldTypeParseError || + e instanceof UnsupportedFieldTypeError + ) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t('nodes.inputFieldTypeParseError', { + node: type, + field: propertyName, + message: e.message, + }) + ); + } } - if (fieldType === 'WorkflowField') { - withWorkflow = true; - return inputsAccumulator; - } - - if (isReservedFieldType(fieldType)) { - logger('nodes').trace( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - `Skipping reserved input field type: ${fieldType}` - ); - return inputsAccumulator; - } - - if (!isFieldType(fieldType)) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - `Skipping unknown input field type: ${fieldType}` - ); - return inputsAccumulator; - } - - const field = buildInputFieldTemplate( - schema, - property, - propertyName, - fieldType - ); - - if (!field) { - logger('nodes').warn( - { - node: type, - fieldName: propertyName, - fieldType, - field: parseify(property), - }, - 'Skipping input field with no template' - ); - return inputsAccumulator; - } - - inputsAccumulator[propertyName] = field; return inputsAccumulator; }, {} @@ -206,7 +180,7 @@ export const parseSchema = ( (outputsAccumulator, property, propertyName) => { if (!isAllowedOutputField(type, propertyName)) { logger('nodes').trace( - { type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Skipped reserved output field' ); return outputsAccumulator; @@ -214,37 +188,62 @@ export const parseSchema = ( if (!isInvocationFieldSchema(property)) { logger('nodes').warn( - { type, propertyName, property: parseify(property) }, + { node: type, field: propertyName, schema: parseify(property) }, 'Unhandled output property' ); return outputsAccumulator; } - const fieldType = property.ui_type ?? getFieldType(property); + try { + const fieldType = parseFieldType(property); - if (!isFieldType(fieldType)) { - logger('nodes').warn( - { fieldName: propertyName, fieldType, field: parseify(property) }, - 'Skipping unknown output field type' - ); - return outputsAccumulator; + if (!fieldType) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + 'Missing output field type' + ); + return outputsAccumulator; + } + + const fieldOutputTemplate: FieldOutputTemplate = { + fieldKind: 'output', + name: propertyName, + title: + property.title ?? (propertyName ? startCase(propertyName) : ''), + description: property.description ?? '', + type: fieldType, + ui_hidden: property.ui_hidden ?? false, + ui_type: property.ui_type, + ui_order: property.ui_order, + }; + + outputsAccumulator[propertyName] = fieldOutputTemplate; + } catch (e) { + if ( + e instanceof FieldTypeParseError || + e instanceof UnsupportedFieldTypeError + ) { + logger('nodes').warn( + { + node: type, + field: propertyName, + schema: parseify(property), + }, + t('nodes.outputFieldTypeParseError', { + node: type, + field: propertyName, + message: e.message, + }) + ); + } } - - outputsAccumulator[propertyName] = { - fieldKind: 'output', - name: propertyName, - title: - property.title ?? (propertyName ? startCase(propertyName) : ''), - description: property.description ?? '', - type: fieldType, - ui_hidden: property.ui_hidden ?? false, - ui_type: property.ui_type, - ui_order: property.ui_order, - }; - return outputsAccumulator; }, - {} as Record + {} as Record ); const useCache = schema.properties.use_cache.default; diff --git a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts index 9e5cea13f6..6d2ee13cf2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts @@ -1,123 +1,159 @@ -import { compareVersions } from 'compare-versions'; -import { cloneDeep, keyBy } from 'lodash-es'; -import { - InvocationTemplate, - Workflow, - WorkflowWarning, - isWorkflowInvocationNode, -} from '../types/types'; import { parseify } from 'common/util/serialize'; -import i18n from 'i18next'; +import { t } from 'i18next'; +import { keyBy } from 'lodash-es'; +import { JsonObject } from 'type-fest'; +import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { InvocationTemplate } from '../types/invocation'; +import { parseAndMigrateWorkflow } from '../types/migration/migrations'; +import { WorkflowV2, isWorkflowInvocationNode } from '../types/workflow'; +type WorkflowWarning = { + message: string; + issues?: string[]; + data: JsonObject; +}; + +type ValidateWorkflowResult = { + workflow: WorkflowV2; + warnings: WorkflowWarning[]; +}; + +/** + * Parses and validates a workflow: + * - Parses the workflow schema, and migrates it to the latest version if necessary. + * - Validates the workflow against the node templates, warning if the template is not known. + * - Attempts to update nodes which have a mismatched version. + * - Removes edges which are invalid. + * @param workflow The raw workflow object (e.g. JSON.parse(stringifiedWorklow)) + * @param invocationTemplates The node templates to validate against. + * @throws {WorkflowVersionError} If the workflow version is not recognized. + * @throws {z.ZodError} If there is a validation error. + */ export const validateWorkflow = ( - workflow: Workflow, - nodeTemplates: Record -) => { - const clone = cloneDeep(workflow); - const { nodes, edges } = clone; - const errors: WorkflowWarning[] = []; + workflow: unknown, + invocationTemplates: Record +): ValidateWorkflowResult => { + // Parse the raw workflow data & migrate it to the latest version + const _workflow = parseAndMigrateWorkflow(workflow); + + // Now we can validate the graph + const { nodes, edges } = _workflow; + const warnings: WorkflowWarning[] = []; + + // We don't need to validate Note nodes or CurrentImage nodes - only Invocation nodes const invocationNodes = nodes.filter(isWorkflowInvocationNode); const keyedNodes = keyBy(invocationNodes, 'id'); - nodes.forEach((node) => { - if (!isWorkflowInvocationNode(node)) { - return; - } - const nodeTemplate = nodeTemplates[node.data.type]; - if (!nodeTemplate) { - errors.push({ - message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t( - 'nodes.skipped' - )}`, - issues: [ - `${i18n.t('nodes.nodeType')}"${node.data.type}" ${i18n.t( - 'nodes.doesNotExist' - )}`, - ], - data: node, + invocationNodes.forEach((node) => { + const template = invocationTemplates[node.data.type]; + if (!template) { + // This node's type template does not exist + const message = t('nodes.missingTemplate', { + node: node.id, + type: node.data.type, + }); + warnings.push({ + message, + data: parseify(node), }); return; } - if ( - nodeTemplate.version && - node.data.version && - compareVersions(nodeTemplate.version, node.data.version) !== 0 - ) { - errors.push({ - message: `${i18n.t('nodes.node')} "${node.data.type}" ${i18n.t( - 'nodes.mismatchedVersion' - )}`, - issues: [ - `${i18n.t('nodes.node')} "${node.data.type}" v${ - node.data.version - } ${i18n.t('nodes.maybeIncompatible')} v${nodeTemplate.version}`, - ], - data: { node, nodeTemplate: parseify(nodeTemplate) }, + if (getNeedsUpdate(node, template)) { + // This node needs to be updated, based on comparison of its version to the template version + const message = t('nodes.mismatchedVersion', { + node: node.id, + type: node.data.type, + }); + warnings.push({ + message, + data: parseify({ node, nodeTemplate: template }), }); return; } }); edges.forEach((edge, i) => { + // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. const sourceNode = keyedNodes[edge.source]; const targetNode = keyedNodes[edge.target]; const issues: string[] = []; + if (!sourceNode) { + // The edge's source/output node does not exist issues.push( - `${i18n.t('nodes.outputNode')} ${edge.source} ${i18n.t( - 'nodes.doesNotExist' - )}` + t('nodes.sourceNodeDoesNotExist', { + node: edge.source, + }) ); } else if ( edge.type === 'default' && !(edge.sourceHandle in sourceNode.data.outputs) ) { + // The edge's source/output node field does not exist issues.push( - `${i18n.t('nodes.outputNode')} "${edge.source}.${ - edge.sourceHandle - }" ${i18n.t('nodes.doesNotExist')}` + t('nodes.sourceNodeFieldDoesNotExist', { + node: edge.source, + field: edge.sourceHandle, + }) ); } + if (!targetNode) { + // The edge's target/input node does not exist issues.push( - `${i18n.t('nodes.inputNode')} ${edge.target} ${i18n.t( - 'nodes.doesNotExist' - )}` + t('nodes.targetNodeDoesNotExist', { + node: edge.target, + }) ); } else if ( edge.type === 'default' && !(edge.targetHandle in targetNode.data.inputs) ) { + // The edge's target/input node field does not exist issues.push( - `${i18n.t('nodes.inputField')} "${edge.target}.${ - edge.targetHandle - }" ${i18n.t('nodes.doesNotExist')}` + t('nodes.targetNodeFieldDoesNotExist', { + node: edge.target, + field: edge.targetHandle, + }) ); } - if (!nodeTemplates[sourceNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + + if (!sourceNode?.data.type || !invocationTemplates[sourceNode.data.type]) { + // The edge's source/output node template does not exist issues.push( - `${i18n.t('nodes.sourceNode')} "${edge.source}" ${i18n.t( - 'nodes.missingTemplate' - )} "${sourceNode?.data.type}"` + t('nodes.missingTemplate', { + node: edge.source, + type: sourceNode?.data.type, + }) ); } - if (!nodeTemplates[targetNode?.data.type ?? '__UNKNOWN_NODE_TYPE__']) { + if (!targetNode?.data.type || !invocationTemplates[targetNode?.data.type]) { + // The edge's target/input node template does not exist issues.push( - `${i18n.t('nodes.sourceNode')}"${edge.target}" ${i18n.t( - 'nodes.missingTemplate' - )} "${targetNode?.data.type}"` + t('nodes.missingTemplate', { + node: edge.target, + type: targetNode?.data.type, + }) ); } + if (issues.length) { + // This edge has some issues. Remove it. delete edges[i]; - const src = edge.type === 'default' ? edge.sourceHandle : edge.source; - const tgt = edge.type === 'default' ? edge.targetHandle : edge.target; - errors.push({ - message: `Edge "${src} -> ${tgt}" skipped`, + const source = + edge.type === 'default' + ? `${edge.source}.${edge.sourceHandle}` + : edge.source; + const target = + edge.type === 'default' + ? `${edge.source}.${edge.targetHandle}` + : edge.target; + warnings.push({ + message: t('nodes.deletedInvalidEdge', { source, target }), issues, data: edge, }); } }); - return { workflow: clone, errors }; + return { workflow: _workflow, warnings }; }; diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx index bff8120b7b..49dd60beb5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Advanced/ParamClipSkip.tsx @@ -3,7 +3,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAISlider from 'common/components/IAISlider'; import { setClipSkip } from 'features/parameters/store/generationSlice'; -import { clipSkipMap } from 'features/parameters/types/constants'; +import { CLIP_SKIP_MAP } from 'features/parameters/types/constants'; import { useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -30,16 +30,16 @@ export default function ParamClipSkip() { const max = useMemo(() => { if (!model) { - return clipSkipMap['sd-1'].maxClip; + return CLIP_SKIP_MAP['sd-1'].maxClip; } - return clipSkipMap[model.base_model].maxClip; + return CLIP_SKIP_MAP[model.base_model].maxClip; }, [model]); const sliderMarks = useMemo(() => { if (!model) { - return clipSkipMap['sd-1'].markers; + return CLIP_SKIP_MAP['sd-1'].markers; } - return clipSkipMap[model.base_model].markers; + return CLIP_SKIP_MAP[model.base_model].markers; }, [model]); if (model?.base_model === 'sdxl') { diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx index 1196719af3..1fe4f95c3b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Canvas/Compositing/CoherencePass/ParamCanvasCoherenceMode.tsx @@ -4,7 +4,7 @@ import IAIInformationalPopover from 'common/components/IAIInformationalPopover/I import { IAISelectDataType } from 'common/components/IAIMantineSearchableSelect'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setCanvasCoherenceMode } from 'features/parameters/store/generationSlice'; -import { CanvasCoherenceModeParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterCanvasCoherenceMode } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -30,7 +30,7 @@ const ParamCanvasCoherenceMode = () => { return; } - dispatch(setCanvasCoherenceMode(v as CanvasCoherenceModeParam)); + dispatch(setCanvasCoherenceMode(v as ParameterCanvasCoherenceMode)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx index a44e6fb551..8f27747497 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/Core/ParamScheduler.tsx @@ -5,10 +5,8 @@ import IAIInformationalPopover from 'common/components/IAIInformationalPopover/I import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; import { generationSelector } from 'features/parameters/store/generationSelectors'; import { setScheduler } from 'features/parameters/store/generationSlice'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { uiSelector } from 'features/ui/store/uiSelectors'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -23,7 +21,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -46,7 +44,7 @@ const ParamScheduler = () => { if (!v) { return; } - dispatch(setScheduler(v as SchedulerParam)); + dispatch(setScheduler(v as ParameterScheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx index 403d2268c1..89c4f51356 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/HighResFix/ParamHrfMethod.tsx @@ -4,7 +4,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { setHrfMethod } from 'features/parameters/store/generationSlice'; -import { HrfMethodParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterHRFMethod } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -26,7 +26,7 @@ const ParamHrfMethodSelect = () => { const { hrfMethod, hrfEnabled } = useAppSelector(selector); const handleChange = useCallback( - (v: HrfMethodParam | null) => { + (v: ParameterHRFMethod | null) => { if (!v) { return; } diff --git a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx index 723e57a288..ad75fa0b7b 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Parameters/VAEModel/ParamVAEPrecision.tsx @@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIInformationalPopover from 'common/components/IAIInformationalPopover/IAIInformationalPopover'; import IAIMantineSelect from 'common/components/IAIMantineSelect'; import { vaePrecisionChanged } from 'features/parameters/store/generationSlice'; -import { PrecisionParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterPrecision } from 'features/parameters/types/parameterSchemas'; import { memo, useCallback } from 'react'; import { useTranslation } from 'react-i18next'; @@ -31,7 +31,7 @@ const ParamVAEModelSelect = () => { return; } - dispatch(vaePrecisionChanged(v as PrecisionParam)); + dispatch(vaePrecisionChanged(v as ParameterPrecision)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/parameters/util/useCoreParametersCollapseLabel.ts b/invokeai/frontend/web/src/features/parameters/hooks/useCoreParametersCollapseLabel.ts similarity index 100% rename from invokeai/frontend/web/src/features/parameters/util/useCoreParametersCollapseLabel.ts rename to invokeai/frontend/web/src/features/parameters/hooks/useCoreParametersCollapseLabel.ts diff --git a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts index 5cecd03753..898fe13618 100644 --- a/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts +++ b/invokeai/frontend/web/src/features/parameters/hooks/useRecallParameters.ts @@ -24,7 +24,7 @@ import { IPAdapterMetadataItem, LoRAMetadataItem, T2IAdapterMetadataItem, -} from 'features/nodes/types/types'; +} from 'features/nodes/types/metadata'; import { refinerModelChanged, setNegativeStylePromptSDXL, @@ -69,28 +69,28 @@ import { vaeSelected, } from '../store/generationSlice'; import { - isValidBoolean, - isValidCfgScale, - isValidControlNetModel, - isValidHeight, - isValidHrfMethod, - isValidIPAdapterModel, - isValidLoRAModel, - isValidMainModel, - isValidNegativePrompt, - isValidPositivePrompt, - isValidSDXLNegativeStylePrompt, - isValidSDXLPositiveStylePrompt, - isValidSDXLRefinerModel, - isValidSDXLRefinerNegativeAestheticScore, - isValidSDXLRefinerPositiveAestheticScore, - isValidSDXLRefinerStart, - isValidScheduler, - isValidSeed, - isValidSteps, - isValidStrength, - isValidVaeModel, - isValidWidth, + isParameterHRFEnabled, + isParameterCFGScale, + isParameterControlNetModel, + isParameterHeight, + isParameterHRFMethod, + isParameterIPAdapterModel, + isParameterLoRAModel, + isParameterModel, + isParameterNegativePrompt, + isParameterPositivePrompt, + isParameterNegativeStylePromptSDXL, + isParameterPositiveStylePromptSDXL, + isParameterSDXLRefinerModel, + isParameterSDXLRefinerNegativeAestheticScore, + isParameterSDXLRefinerPositiveAestheticScore, + isParameterSDXLRefinerStart, + isParameterScheduler, + isParameterSeed, + isParameterSteps, + isParameterStrength, + isParameterVAEModel, + isParameterWidth, } from '../types/parameterSchemas'; const selector = createSelector( @@ -160,24 +160,24 @@ export const useRecallParameters = () => { negativeStylePrompt: unknown ) => { if ( - isValidPositivePrompt(positivePrompt) || - isValidNegativePrompt(negativePrompt) || - isValidSDXLPositiveStylePrompt(positiveStylePrompt) || - isValidSDXLNegativeStylePrompt(negativeStylePrompt) + isParameterPositivePrompt(positivePrompt) || + isParameterNegativePrompt(negativePrompt) || + isParameterPositiveStylePromptSDXL(positiveStylePrompt) || + isParameterNegativeStylePromptSDXL(negativeStylePrompt) ) { - if (isValidPositivePrompt(positivePrompt)) { + if (isParameterPositivePrompt(positivePrompt)) { dispatch(setPositivePrompt(positivePrompt)); } - if (isValidNegativePrompt(negativePrompt)) { + if (isParameterNegativePrompt(negativePrompt)) { dispatch(setNegativePrompt(negativePrompt)); } - if (isValidSDXLPositiveStylePrompt(positiveStylePrompt)) { + if (isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { dispatch(setPositiveStylePromptSDXL(positiveStylePrompt)); } - if (isValidSDXLPositiveStylePrompt(negativeStylePrompt)) { + if (isParameterPositiveStylePromptSDXL(negativeStylePrompt)) { dispatch(setNegativeStylePromptSDXL(negativeStylePrompt)); } @@ -194,7 +194,7 @@ export const useRecallParameters = () => { */ const recallPositivePrompt = useCallback( (positivePrompt: unknown) => { - if (!isValidPositivePrompt(positivePrompt)) { + if (!isParameterPositivePrompt(positivePrompt)) { parameterNotSetToast(); return; } @@ -209,7 +209,7 @@ export const useRecallParameters = () => { */ const recallNegativePrompt = useCallback( (negativePrompt: unknown) => { - if (!isValidNegativePrompt(negativePrompt)) { + if (!isParameterNegativePrompt(negativePrompt)) { parameterNotSetToast(); return; } @@ -224,7 +224,7 @@ export const useRecallParameters = () => { */ const recallSDXLPositiveStylePrompt = useCallback( (positiveStylePrompt: unknown) => { - if (!isValidSDXLPositiveStylePrompt(positiveStylePrompt)) { + if (!isParameterPositiveStylePromptSDXL(positiveStylePrompt)) { parameterNotSetToast(); return; } @@ -239,7 +239,7 @@ export const useRecallParameters = () => { */ const recallSDXLNegativeStylePrompt = useCallback( (negativeStylePrompt: unknown) => { - if (!isValidSDXLNegativeStylePrompt(negativeStylePrompt)) { + if (!isParameterNegativeStylePromptSDXL(negativeStylePrompt)) { parameterNotSetToast(); return; } @@ -254,7 +254,7 @@ export const useRecallParameters = () => { */ const recallSeed = useCallback( (seed: unknown) => { - if (!isValidSeed(seed)) { + if (!isParameterSeed(seed)) { parameterNotSetToast(); return; } @@ -269,7 +269,7 @@ export const useRecallParameters = () => { */ const recallCfgScale = useCallback( (cfgScale: unknown) => { - if (!isValidCfgScale(cfgScale)) { + if (!isParameterCFGScale(cfgScale)) { parameterNotSetToast(); return; } @@ -284,7 +284,7 @@ export const useRecallParameters = () => { */ const recallModel = useCallback( (model: unknown) => { - if (!isValidMainModel(model)) { + if (!isParameterModel(model)) { parameterNotSetToast(); return; } @@ -299,7 +299,7 @@ export const useRecallParameters = () => { */ const recallScheduler = useCallback( (scheduler: unknown) => { - if (!isValidScheduler(scheduler)) { + if (!isParameterScheduler(scheduler)) { parameterNotSetToast(); return; } @@ -314,7 +314,7 @@ export const useRecallParameters = () => { */ const recallVaeModel = useCallback( (vae: unknown) => { - if (!isValidVaeModel(vae) && !isNil(vae)) { + if (!isParameterVAEModel(vae) && !isNil(vae)) { parameterNotSetToast(); return; } @@ -333,7 +333,7 @@ export const useRecallParameters = () => { */ const recallSteps = useCallback( (steps: unknown) => { - if (!isValidSteps(steps)) { + if (!isParameterSteps(steps)) { parameterNotSetToast(); return; } @@ -348,7 +348,7 @@ export const useRecallParameters = () => { */ const recallWidth = useCallback( (width: unknown) => { - if (!isValidWidth(width)) { + if (!isParameterWidth(width)) { parameterNotSetToast(); return; } @@ -363,7 +363,7 @@ export const useRecallParameters = () => { */ const recallHeight = useCallback( (height: unknown) => { - if (!isValidHeight(height)) { + if (!isParameterHeight(height)) { parameterNotSetToast(); return; } @@ -378,11 +378,11 @@ export const useRecallParameters = () => { */ const recallWidthAndHeight = useCallback( (width: unknown, height: unknown) => { - if (!isValidWidth(width)) { + if (!isParameterWidth(width)) { allParameterNotSetToast(); return; } - if (!isValidHeight(height)) { + if (!isParameterHeight(height)) { allParameterNotSetToast(); return; } @@ -398,7 +398,7 @@ export const useRecallParameters = () => { */ const recallStrength = useCallback( (strength: unknown) => { - if (!isValidStrength(strength)) { + if (!isParameterStrength(strength)) { parameterNotSetToast(); return; } @@ -413,7 +413,7 @@ export const useRecallParameters = () => { */ const recallHrfEnabled = useCallback( (hrfEnabled: unknown) => { - if (!isValidBoolean(hrfEnabled)) { + if (!isParameterHRFEnabled(hrfEnabled)) { parameterNotSetToast(); return; } @@ -428,7 +428,7 @@ export const useRecallParameters = () => { */ const recallHrfStrength = useCallback( (hrfStrength: unknown) => { - if (!isValidStrength(hrfStrength)) { + if (!isParameterStrength(hrfStrength)) { parameterNotSetToast(); return; } @@ -443,7 +443,7 @@ export const useRecallParameters = () => { */ const recallHrfMethod = useCallback( (hrfMethod: unknown) => { - if (!isValidHrfMethod(hrfMethod)) { + if (!isParameterHRFMethod(hrfMethod)) { parameterNotSetToast(); return; } @@ -461,7 +461,7 @@ export const useRecallParameters = () => { const prepareLoRAMetadataItem = useCallback( (loraMetadataItem: LoRAMetadataItem) => { - if (!isValidLoRAModel(loraMetadataItem.lora)) { + if (!isParameterLoRAModel(loraMetadataItem.lora)) { return { lora: null, error: 'Invalid LoRA model' }; } @@ -518,7 +518,7 @@ export const useRecallParameters = () => { const prepareControlNetMetadataItem = useCallback( (controlnetMetadataItem: ControlNetMetadataItem) => { - if (!isValidControlNetModel(controlnetMetadataItem.control_model)) { + if (!isParameterControlNetModel(controlnetMetadataItem.control_model)) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -613,7 +613,9 @@ export const useRecallParameters = () => { const prepareT2IAdapterMetadataItem = useCallback( (t2iAdapterMetadataItem: T2IAdapterMetadataItem) => { - if (!isValidControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model)) { + if ( + !isParameterControlNetModel(t2iAdapterMetadataItem.t2i_adapter_model) + ) { return { controlnet: null, error: 'Invalid ControlNet model' }; } @@ -703,7 +705,7 @@ export const useRecallParameters = () => { const prepareIPAdapterMetadataItem = useCallback( (ipAdapterMetadataItem: IPAdapterMetadataItem) => { - if (!isValidIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { + if (!isParameterIPAdapterModel(ipAdapterMetadataItem?.ip_adapter_model)) { return { ipAdapter: null, error: 'Invalid IP Adapter model' }; } @@ -822,26 +824,26 @@ export const useRecallParameters = () => { t2iAdapters, } = metadata; - if (isValidCfgScale(cfg_scale)) { + if (isParameterCFGScale(cfg_scale)) { dispatch(setCfgScale(cfg_scale)); } - if (isValidMainModel(model)) { + if (isParameterModel(model)) { dispatch(modelSelected(model)); } - if (isValidPositivePrompt(positive_prompt)) { + if (isParameterPositivePrompt(positive_prompt)) { dispatch(setPositivePrompt(positive_prompt)); } - if (isValidNegativePrompt(negative_prompt)) { + if (isParameterNegativePrompt(negative_prompt)) { dispatch(setNegativePrompt(negative_prompt)); } - if (isValidScheduler(scheduler)) { + if (isParameterScheduler(scheduler)) { dispatch(setScheduler(scheduler)); } - if (isValidVaeModel(vae) || isNil(vae)) { + if (isParameterVAEModel(vae) || isNil(vae)) { if (isNil(vae)) { dispatch(vaeSelected(null)); } else { @@ -849,64 +851,64 @@ export const useRecallParameters = () => { } } - if (isValidSeed(seed)) { + if (isParameterSeed(seed)) { dispatch(setSeed(seed)); } - if (isValidSteps(steps)) { + if (isParameterSteps(steps)) { dispatch(setSteps(steps)); } - if (isValidWidth(width)) { + if (isParameterWidth(width)) { dispatch(setWidth(width)); } - if (isValidHeight(height)) { + if (isParameterHeight(height)) { dispatch(setHeight(height)); } - if (isValidStrength(strength)) { + if (isParameterStrength(strength)) { dispatch(setImg2imgStrength(strength)); } - if (isValidBoolean(hrf_enabled)) { + if (isParameterHRFEnabled(hrf_enabled)) { dispatch(setHrfEnabled(hrf_enabled)); } - if (isValidStrength(hrf_strength)) { + if (isParameterStrength(hrf_strength)) { dispatch(setHrfStrength(hrf_strength)); } - if (isValidHrfMethod(hrf_method)) { + if (isParameterHRFMethod(hrf_method)) { dispatch(setHrfMethod(hrf_method)); } - if (isValidSDXLPositiveStylePrompt(positive_style_prompt)) { + if (isParameterPositiveStylePromptSDXL(positive_style_prompt)) { dispatch(setPositiveStylePromptSDXL(positive_style_prompt)); } - if (isValidSDXLNegativeStylePrompt(negative_style_prompt)) { + if (isParameterNegativeStylePromptSDXL(negative_style_prompt)) { dispatch(setNegativeStylePromptSDXL(negative_style_prompt)); } - if (isValidSDXLRefinerModel(refiner_model)) { + if (isParameterSDXLRefinerModel(refiner_model)) { dispatch(refinerModelChanged(refiner_model)); } - if (isValidSteps(refiner_steps)) { + if (isParameterSteps(refiner_steps)) { dispatch(setRefinerSteps(refiner_steps)); } - if (isValidCfgScale(refiner_cfg_scale)) { + if (isParameterCFGScale(refiner_cfg_scale)) { dispatch(setRefinerCFGScale(refiner_cfg_scale)); } - if (isValidScheduler(refiner_scheduler)) { + if (isParameterScheduler(refiner_scheduler)) { dispatch(setRefinerScheduler(refiner_scheduler)); } if ( - isValidSDXLRefinerPositiveAestheticScore( + isParameterSDXLRefinerPositiveAestheticScore( refiner_positive_aesthetic_score ) ) { @@ -916,7 +918,7 @@ export const useRecallParameters = () => { } if ( - isValidSDXLRefinerNegativeAestheticScore( + isParameterSDXLRefinerNegativeAestheticScore( refiner_negative_aesthetic_score ) ) { @@ -925,7 +927,7 @@ export const useRecallParameters = () => { ); } - if (isValidSDXLRefinerStart(refiner_start)) { + if (isParameterSDXLRefinerStart(refiner_start)) { dispatch(setRefinerStart(refiner_start)); } diff --git a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts index 8fbdfafbde..e23747c921 100644 --- a/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts +++ b/invokeai/frontend/web/src/features/parameters/store/generationSlice.ts @@ -6,63 +6,62 @@ import { clamp } from 'lodash-es'; import { ImageDTO } from 'services/api/types'; import { isAnyControlAdapterAdded } from 'features/controlAdapters/store/controlAdaptersSlice'; -import { clipSkipMap } from '../types/constants'; +import { CLIP_SKIP_MAP } from '../types/constants'; import { - CanvasCoherenceModeParam, - CfgScaleParam, - HeightParam, - HrfMethodParam, - MainModelParam, - MaskBlurMethodParam, - NegativePromptParam, - OnnxModelParam, - PositivePromptParam, - PrecisionParam, - SchedulerParam, - SeedParam, - StepsParam, - StrengthParam, - VaeModelParam, - WidthParam, - zMainModel, + ParameterCanvasCoherenceMode, + ParameterCFGScale, + ParameterHeight, + ParameterHRFMethod, + ParameterModel, + ParameterMaskBlurMethod, + ParameterNegativePrompt, + ParameterPositivePrompt, + ParameterPrecision, + ParameterScheduler, + ParameterSeed, + ParameterSteps, + ParameterStrength, + ParameterVAEModel, + ParameterWidth, + zParameterModel, } from '../types/parameterSchemas'; export interface GenerationState { hrfEnabled: boolean; - hrfStrength: StrengthParam; - hrfMethod: HrfMethodParam; - cfgScale: CfgScaleParam; - height: HeightParam; - img2imgStrength: StrengthParam; + hrfStrength: ParameterStrength; + hrfMethod: ParameterHRFMethod; + cfgScale: ParameterCFGScale; + height: ParameterHeight; + img2imgStrength: ParameterStrength; infillMethod: string; initialImage?: { imageName: string; width: number; height: number }; iterations: number; perlin: number; - positivePrompt: PositivePromptParam; - negativePrompt: NegativePromptParam; - scheduler: SchedulerParam; + positivePrompt: ParameterPositivePrompt; + negativePrompt: ParameterNegativePrompt; + scheduler: ParameterScheduler; maskBlur: number; - maskBlurMethod: MaskBlurMethodParam; - canvasCoherenceMode: CanvasCoherenceModeParam; + maskBlurMethod: ParameterMaskBlurMethod; + canvasCoherenceMode: ParameterCanvasCoherenceMode; canvasCoherenceSteps: number; - canvasCoherenceStrength: StrengthParam; - seed: SeedParam; + canvasCoherenceStrength: ParameterStrength; + seed: ParameterSeed; seedWeights: string; shouldFitToWidthHeight: boolean; shouldGenerateVariations: boolean; shouldRandomizeSeed: boolean; - steps: StepsParam; + steps: ParameterSteps; threshold: number; infillTileSize: number; infillPatchmatchDownscaleSize: number; variationAmount: number; - width: WidthParam; + width: ParameterWidth; shouldUseSymmetry: boolean; horizontalSymmetrySteps: number; verticalSymmetrySteps: number; - model: MainModelParam | OnnxModelParam | null; - vae: VaeModelParam | null; - vaePrecision: PrecisionParam; + model: ParameterModel | null; + vae: ParameterVAEModel | null; + vaePrecision: ParameterPrecision; seamlessXAxis: boolean; seamlessYAxis: boolean; clipSkip: number; @@ -166,7 +165,7 @@ export const generationSlice = createSlice({ state.width = height; state.height = width; }, - setScheduler: (state, action: PayloadAction) => { + setScheduler: (state, action: PayloadAction) => { state.scheduler = action.payload; }, setSeed: (state, action: PayloadAction) => { @@ -214,12 +213,15 @@ export const generationSlice = createSlice({ setMaskBlur: (state, action: PayloadAction) => { state.maskBlur = action.payload; }, - setMaskBlurMethod: (state, action: PayloadAction) => { + setMaskBlurMethod: ( + state, + action: PayloadAction + ) => { state.maskBlurMethod = action.payload; }, setCanvasCoherenceMode: ( state, - action: PayloadAction + action: PayloadAction ) => { state.canvasCoherenceMode = action.payload; }, @@ -254,10 +256,7 @@ export const generationSlice = createSlice({ const { image_name, width, height } = action.payload; state.initialImage = { imageName: image_name, width, height }; }, - modelChanged: ( - state, - action: PayloadAction - ) => { + modelChanged: (state, action: PayloadAction) => { state.model = action.payload; if (state.model === null) { @@ -265,14 +264,14 @@ export const generationSlice = createSlice({ } // Clamp ClipSkip Based On Selected Model - const { maxClip } = clipSkipMap[state.model.base_model]; + const { maxClip } = CLIP_SKIP_MAP[state.model.base_model]; state.clipSkip = clamp(state.clipSkip, 0, maxClip); }, - vaeSelected: (state, action: PayloadAction) => { + vaeSelected: (state, action: PayloadAction) => { // null is a valid VAE! state.vae = action.payload; }, - vaePrecisionChanged: (state, action: PayloadAction) => { + vaePrecisionChanged: (state, action: PayloadAction) => { state.vaePrecision = action.payload; }, setClipSkip: (state, action: PayloadAction) => { @@ -284,7 +283,7 @@ export const generationSlice = createSlice({ setHrfEnabled: (state, action: PayloadAction) => { state.hrfEnabled = action.payload; }, - setHrfMethod: (state, action: PayloadAction) => { + setHrfMethod: (state, action: PayloadAction) => { state.hrfMethod = action.payload; }, shouldUseCpuNoiseChanged: (state, action: PayloadAction) => { @@ -308,7 +307,7 @@ export const generationSlice = createSlice({ if (defaultModel && !state.model) { const [base_model, model_type, model_name] = defaultModel.split('/'); - const result = zMainModel.safeParse({ + const result = zParameterModel.safeParse({ model_name, base_model, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/types/constants.ts b/invokeai/frontend/web/src/features/parameters/types/constants.ts index 4494d235af..2d9fa62a79 100644 --- a/invokeai/frontend/web/src/features/parameters/types/constants.ts +++ b/invokeai/frontend/web/src/features/parameters/types/constants.ts @@ -1,5 +1,9 @@ -import { components } from 'services/api/schema'; +import { SchedulerField } from 'features/nodes/types/common'; +import { LoRAModelFormat } from 'services/api/types'; +/** + * Mapping of model type to human readable name + */ export const MODEL_TYPE_MAP = { any: 'Any', 'sd-1': 'Stable Diffusion 1.x', @@ -8,6 +12,9 @@ export const MODEL_TYPE_MAP = { 'sdxl-refiner': 'Stable Diffusion XL Refiner', }; +/** + * Mapping of model type to (short) human readable name + */ export const MODEL_TYPE_SHORT_MAP = { any: 'Any', 'sd-1': 'SD1', @@ -16,7 +23,10 @@ export const MODEL_TYPE_SHORT_MAP = { 'sdxl-refiner': 'SDXLR', }; -export const clipSkipMap = { +/** + * Mapping of model type to CLIP skip parameter constraints + */ +export const CLIP_SKIP_MAP = { any: { maxClip: 0, markers: [], @@ -39,11 +49,41 @@ export const clipSkipMap = { }, }; -type LoRAModelFormatMap = { - [key in components['schemas']['LoRAModelFormat']]: string; -}; - -export const LORA_MODEL_FORMAT_MAP: LoRAModelFormatMap = { +/** + * Mapping of LoRA format to human readable name + */ +export const LORA_MODEL_FORMAT_MAP: { + [key in LoRAModelFormat]: string; +} = { lycoris: 'LyCORIS', diffusers: 'Diffusers', }; + +/** + * Mapping of schedulers to human readable name + */ +export const SCHEDULER_LABEL_MAP: Record = { + euler: 'Euler', + deis: 'DEIS', + ddim: 'DDIM', + ddpm: 'DDPM', + dpmpp_sde: 'DPM++ SDE', + dpmpp_2s: 'DPM++ 2S', + dpmpp_2m: 'DPM++ 2M', + dpmpp_2m_sde: 'DPM++ 2M SDE', + heun: 'Heun', + kdpm_2: 'KDPM 2', + lms: 'LMS', + pndm: 'PNDM', + unipc: 'UniPC', + euler_k: 'Euler Karras', + dpmpp_sde_k: 'DPM++ SDE Karras', + dpmpp_2s_k: 'DPM++ 2S Karras', + dpmpp_2m_k: 'DPM++ 2M Karras', + dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras', + heun_k: 'Heun Karras', + lms_k: 'LMS Karras', + euler_a: 'Euler Ancestral', + kdpm_2_a: 'KDPM 2 Ancestral', + lcm: 'LCM', +}; diff --git a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts index ec3f9baba1..a96e8af002 100644 --- a/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts +++ b/invokeai/frontend/web/src/features/parameters/types/parameterSchemas.ts @@ -1,522 +1,269 @@ import { NUMPY_RAND_MAX } from 'app/constants'; +import { + zControlNetModelField, + zIPAdapterModelField, + zLoRAModelField, + zMainOrONNXModelField, + zSDXLRefinerModelField, + zSchedulerField, + zT2IAdapterModelField, + zVAEModelField, +} from 'features/nodes/types/common'; import { z } from 'zod'; /** - * These zod schemas should match the pydantic node schemas. + * Schemas, types and type guards for parameters. * - * Parameters only need schemas if we want to recall them from metadata. + * Parameters need schemas if we want to recall them from metadata or some untrusted source. * * Each parameter needs: * - a zod schema * - a type alias, inferred from the zod schema - * - a combo validation/type guard function, which returns true if the value is valid + * - a combo validation/type guard function, which returns true if the value is valid, should + * simply be the zod schema's safeParse function */ -/** - * Zod schema for positive prompt parameter - */ -export const zPositivePrompt = z.string(); -/** - * Type alias for positive prompt parameter, inferred from its zod schema - */ -export type PositivePromptParam = z.infer; -/** - * Validates/type-guards a value as a positive prompt parameter - */ -export const isValidPositivePrompt = ( +// #region Positive prompt +export const zParameterPositivePrompt = z.string(); +export type ParameterPositivePrompt = z.infer; +export const isParameterPositivePrompt = ( val: unknown -): val is PositivePromptParam => zPositivePrompt.safeParse(val).success; +): val is ParameterPositivePrompt => + zParameterPositivePrompt.safeParse(val).success; +// #endregion -/** - * Zod schema for negative prompt parameter - */ -export const zNegativePrompt = z.string(); -/** - * Type alias for negative prompt parameter, inferred from its zod schema - */ -export type NegativePromptParam = z.infer; -/** - * Validates/type-guards a value as a negative prompt parameter - */ -export const isValidNegativePrompt = ( +// #region Negative prompt +export const zParameterNegativePrompt = z.string(); +export type ParameterNegativePrompt = z.infer; +export const isParameterNegativePrompt = ( val: unknown -): val is NegativePromptParam => zNegativePrompt.safeParse(val).success; +): val is ParameterNegativePrompt => + zParameterNegativePrompt.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL positive style prompt parameter - */ -export const zPositiveStylePromptSDXL = z.string(); -/** - * Type alias for SDXL positive style prompt parameter, inferred from its zod schema - */ -export type PositiveStylePromptSDXLParam = z.infer< - typeof zPositiveStylePromptSDXL +// #region Positive style prompt (SDXL) +export const zParameterPositiveStylePromptSDXL = z.string(); +export type ParameterPositiveStylePromptSDXL = z.infer< + typeof zParameterPositiveStylePromptSDXL >; -/** - * Validates/type-guards a value as a SDXL positive style prompt parameter - */ -export const isValidSDXLPositiveStylePrompt = ( +export const isParameterPositiveStylePromptSDXL = ( val: unknown -): val is PositiveStylePromptSDXLParam => - zPositiveStylePromptSDXL.safeParse(val).success; +): val is ParameterPositiveStylePromptSDXL => + zParameterPositiveStylePromptSDXL.safeParse(val).success; +// #endregion -/** - * Zod schema for SDXL negative style prompt parameter - */ -export const zNegativeStylePromptSDXL = z.string(); -/** - * Type alias for SDXL negative style prompt parameter, inferred from its zod schema - */ -export type NegativeStylePromptSDXLParam = z.infer< - typeof zNegativeStylePromptSDXL +// #region Positive style prompt (SDXL) +export const zParameterNegativeStylePromptSDXL = z.string(); +export type ParameterNegativeStylePromptSDXL = z.infer< + typeof zParameterNegativeStylePromptSDXL >; -/** - * Validates/type-guards a value as a SDXL negative style prompt parameter - */ -export const isValidSDXLNegativeStylePrompt = ( +export const isParameterNegativeStylePromptSDXL = ( val: unknown -): val is NegativeStylePromptSDXLParam => - zNegativeStylePromptSDXL.safeParse(val).success; +): val is ParameterNegativeStylePromptSDXL => + zParameterNegativeStylePromptSDXL.safeParse(val).success; +// #endregion -/** - * Zod schema for steps parameter - */ -export const zSteps = z.number().int().min(1); -/** - * Type alias for steps parameter, inferred from its zod schema - */ -export type StepsParam = z.infer; -/** - * Validates/type-guards a value as a steps parameter - */ -export const isValidSteps = (val: unknown): val is StepsParam => - zSteps.safeParse(val).success; +// #region Steps +export const zParameterSteps = z.number().int().min(1); +export type ParameterSteps = z.infer; +export const isParameterSteps = (val: unknown): val is ParameterSteps => + zParameterSteps.safeParse(val).success; +// #endregion -/** - * Zod schema for CFG scale parameter - */ -export const zCfgScale = z.number().min(1); -/** - * Type alias for CFG scale parameter, inferred from its zod schema - */ -export type CfgScaleParam = z.infer; -/** - * Validates/type-guards a value as a CFG scale parameter - */ -export const isValidCfgScale = (val: unknown): val is CfgScaleParam => - zCfgScale.safeParse(val).success; +// #region CFG scale parameter +export const zParameterCFGScale = z.number().min(1); +export type ParameterCFGScale = z.infer; +export const isParameterCFGScale = (val: unknown): val is ParameterCFGScale => + zParameterCFGScale.safeParse(val).success; +// #endregion -/** - * Zod schema for scheduler parameter - */ -export const zScheduler = z.enum([ - 'euler', - 'deis', - 'ddim', - 'ddpm', - 'dpmpp_2s', - 'dpmpp_2m', - 'dpmpp_2m_sde', - 'dpmpp_sde', - 'heun', - 'kdpm_2', - 'lms', - 'pndm', - 'unipc', - 'euler_k', - 'dpmpp_2s_k', - 'dpmpp_2m_k', - 'dpmpp_2m_sde_k', - 'dpmpp_sde_k', - 'heun_k', - 'lms_k', - 'euler_a', - 'kdpm_2_a', - 'lcm', +// #region Scheduler +export const zParameterScheduler = zSchedulerField; +export type ParameterScheduler = z.infer; +export const isParameterScheduler = (val: unknown): val is ParameterScheduler => + zParameterScheduler.safeParse(val).success; +// #endregion + +// #region seed +export const zParameterSeed = z.number().int().min(0).max(NUMPY_RAND_MAX); +export type ParameterSeed = z.infer; +export const isParameterSeed = (val: unknown): val is ParameterSeed => + zParameterSeed.safeParse(val).success; +// #endregion + +// #region Width +export const zParameterWidth = z.number().multipleOf(8).min(64); +export type ParameterWidth = z.infer; +export const isParameterWidth = (val: unknown): val is ParameterWidth => + zParameterWidth.safeParse(val).success; +// #endregion + +// #region Height +export const zParameterHeight = zParameterWidth; +export type ParameterHeight = z.infer; +export const isParameterHeight = (val: unknown): val is ParameterHeight => + zParameterHeight.safeParse(val).success; +// #endregion + +// #region Resolution +export const zParameterResolution = z.tuple([ + zParameterWidth, + zParameterHeight, ]); -/** - * Type alias for scheduler parameter, inferred from its zod schema - */ -export type SchedulerParam = z.infer; -/** - * Validates/type-guards a value as a scheduler parameter - */ -export const isValidScheduler = (val: unknown): val is SchedulerParam => - zScheduler.safeParse(val).success; +export type ParameterResolution = z.infer; +export const iParameterResolution = ( + val: unknown +): val is ParameterResolution => zParameterResolution.safeParse(val).success; +// #endregion -export const SCHEDULER_LABEL_MAP: Record = { - euler: 'Euler', - deis: 'DEIS', - ddim: 'DDIM', - ddpm: 'DDPM', - dpmpp_sde: 'DPM++ SDE', - dpmpp_2s: 'DPM++ 2S', - dpmpp_2m: 'DPM++ 2M', - dpmpp_2m_sde: 'DPM++ 2M SDE', - heun: 'Heun', - kdpm_2: 'KDPM 2', - lms: 'LMS', - pndm: 'PNDM', - unipc: 'UniPC', - euler_k: 'Euler Karras', - dpmpp_sde_k: 'DPM++ SDE Karras', - dpmpp_2s_k: 'DPM++ 2S Karras', - dpmpp_2m_k: 'DPM++ 2M Karras', - dpmpp_2m_sde_k: 'DPM++ 2M SDE Karras', - heun_k: 'Heun Karras', - lms_k: 'LMS Karras', - euler_a: 'Euler Ancestral', - kdpm_2_a: 'KDPM 2 Ancestral', - lcm: 'LCM', -}; +// #region Model +export const zParameterModel = zMainOrONNXModelField; +export type ParameterModel = z.infer; +export const isParameterModel = (val: unknown): val is ParameterModel => + zParameterModel.safeParse(val).success; +// #endregion -/** - * Zod schema for seed parameter - */ -export const zSeed = z.number().int().min(0).max(NUMPY_RAND_MAX); -/** - * Type alias for seed parameter, inferred from its zod schema - */ -export type SeedParam = z.infer; -/** - * Validates/type-guards a value as a seed parameter - */ -export const isValidSeed = (val: unknown): val is SeedParam => - zSeed.safeParse(val).success; +// #region SDXL Refiner Model +export const zParameterSDXLRefinerModel = zSDXLRefinerModelField; +export type ParameterSDXLRefinerModel = z.infer< + typeof zParameterSDXLRefinerModel +>; +export const isParameterSDXLRefinerModel = ( + val: unknown +): val is ParameterSDXLRefinerModel => + zParameterSDXLRefinerModel.safeParse(val).success; +// #endregion -/** - * Zod schema for width parameter - */ -export const zWidth = z.number().multipleOf(8).min(64); -/** - * Type alias for width parameter, inferred from its zod schema - */ -export type WidthParam = z.infer; -/** - * Validates/type-guards a value as a width parameter - */ -export const isValidWidth = (val: unknown): val is WidthParam => - zWidth.safeParse(val).success; +// #region VAE Model +export const zParameterVAEModel = zVAEModelField; +export type ParameterVAEModel = z.infer; +export const isParameterVAEModel = (val: unknown): val is ParameterVAEModel => + zParameterVAEModel.safeParse(val).success; +// #endregion -/** - * Zod schema for height parameter - */ -export const zHeight = z.number().multipleOf(8).min(64); -/** - * Type alias for height parameter, inferred from its zod schema - */ -export type HeightParam = z.infer; -/** - * Validates/type-guards a value as a height parameter - */ -export const isValidHeight = (val: unknown): val is HeightParam => - zHeight.safeParse(val).success; +// #region LoRA Model +export const zParameterLoRAModel = zLoRAModelField; +export type ParameterLoRAModel = z.infer; +export const isParameterLoRAModel = (val: unknown): val is ParameterLoRAModel => + zParameterLoRAModel.safeParse(val).success; +// #endregion -/** - * Zod schema for resolution parameter - */ -export const zResolution = z.tuple([zWidth, zHeight]); -/** - * Type alias for resolution parameter, inferred from its zod schema - */ -export type ResolutionParam = z.infer; +// #region ControlNet Model +export const zParameterControlNetModel = zControlNetModelField; +export type ParameterControlNetModel = z.infer; +export const isParameterControlNetModel = ( + val: unknown +): val is ParameterControlNetModel => + zParameterControlNetModel.safeParse(val).success; +// #endregion -export const zBaseModel = z.enum([ - 'any', - 'sd-1', - 'sd-2', - 'sdxl', - 'sdxl-refiner', +// #region IP Adapter Model +export const zParameterIPAdapterModel = zIPAdapterModelField; +export type ParameterIPAdapterModel = z.infer; +export const isParameterIPAdapterModel = ( + val: unknown +): val is ParameterIPAdapterModel => + zParameterIPAdapterModel.safeParse(val).success; +// #endregion + +// #region T2I Adapter Model +export const zParameterT2IAdapterModel = zT2IAdapterModelField; +export type ParameterT2IAdapterModel = z.infer< + typeof zParameterT2IAdapterModel +>; +export const isParameterT2IAdapterModel = ( + val: unknown +): val is ParameterT2IAdapterModel => + zParameterT2IAdapterModel.safeParse(val).success; +// #endregion + +// #region Strength (l2l strength) +export const zParameterStrength = z.number().min(0).max(1); +export type ParameterStrength = z.infer; +export const isParameterStrength = (val: unknown): val is ParameterStrength => + zParameterStrength.safeParse(val).success; +// #endregion + +// #region Precision +export const zParameterPrecision = z.enum(['fp16', 'fp32']); +export type ParameterPrecision = z.infer; +export const isParameterPrecision = (val: unknown): val is ParameterPrecision => + zParameterPrecision.safeParse(val).success; +// #endregion + +// #region HRF Method +export const zParameterHRFMethod = z.enum(['ESRGAN', 'bilinear']); +export type ParameterHRFMethod = z.infer; +export const isParameterHRFMethod = (val: unknown): val is ParameterHRFMethod => + zParameterHRFMethod.safeParse(val).success; +// #endregion + +// #region HRF Enabled +export const zParameterHRFEnabled = z.boolean(); +export type ParameterHRFEnabled = z.infer; +export const isParameterHRFEnabled = (val: unknown): val is boolean => + zParameterHRFEnabled.safeParse(val).success && + val !== null && + val !== undefined; +// #endregion + +// #region SDXL Refiner Positive Aesthetic Score +export const zParameterSDXLRefinerPositiveAestheticScore = z + .number() + .min(1) + .max(10); +export type ParameterSDXLRefinerPositiveAestheticScore = z.infer< + typeof zParameterSDXLRefinerPositiveAestheticScore +>; +export const isParameterSDXLRefinerPositiveAestheticScore = ( + val: unknown +): val is ParameterSDXLRefinerPositiveAestheticScore => + zParameterSDXLRefinerPositiveAestheticScore.safeParse(val).success; +// #endregion + +// #region SDXL Refiner Negative Aesthetic Score +export const zParameterSDXLRefinerNegativeAestheticScore = + zParameterSDXLRefinerPositiveAestheticScore; +export type ParameterSDXLRefinerNegativeAestheticScore = z.infer< + typeof zParameterSDXLRefinerNegativeAestheticScore +>; +export const isParameterSDXLRefinerNegativeAestheticScore = ( + val: unknown +): val is ParameterSDXLRefinerNegativeAestheticScore => + zParameterSDXLRefinerNegativeAestheticScore.safeParse(val).success; +// #endregion + +// #region SDXL Refiner Start +export const zParameterSDXLRefinerStart = z.number().min(0).max(1); +export type ParameterSDXLRefinerStart = z.infer< + typeof zParameterSDXLRefinerStart +>; +export const isParameterSDXLRefinerStart = ( + val: unknown +): val is ParameterSDXLRefinerStart => + zParameterSDXLRefinerStart.safeParse(val).success; +// #endregion + +// #region Mask Blur Method +export const zParameterMaskBlurMethod = z.enum(['box', 'gaussian']); +export type ParameterMaskBlurMethod = z.infer; +export const isParameterMaskBlurMethod = ( + val: unknown +): val is ParameterMaskBlurMethod => + zParameterMaskBlurMethod.safeParse(val).success; +// #endregion + +// #region Canvas Coherence Mode +export const zParameterCanvasCoherenceMode = z.enum([ + 'unmasked', + 'mask', + 'edge', ]); - -export type BaseModelParam = z.infer; - -/** - * Zod schema for main model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zMainModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, - model_type: z.literal('main'), -}); -/** - * Type alias for main model parameter, inferred from its zod schema - */ -export type MainModelParam = z.infer; -/** - * Validates/type-guards a value as a main model parameter - */ -export const isValidMainModel = (val: unknown): val is MainModelParam => - zMainModel.safeParse(val).success; - -/** - * Zod schema for SDXL refiner model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zSDXLRefinerModel = z.object({ - model_name: z.string().min(1), - base_model: z.literal('sdxl-refiner'), - model_type: z.literal('main'), -}); -/** - * Type alias for SDXL refiner model parameter, inferred from its zod schema - */ -export type SDXLRefinerModelParam = z.infer; -/** - * Validates/type-guards a value as a SDXL refiner model parameter - */ -export const isValidSDXLRefinerModel = ( - val: unknown -): val is SDXLRefinerModelParam => zSDXLRefinerModel.safeParse(val).success; - -/** - * Zod schema for Onnx model parameter - * TODO: Make this a dynamically generated enum? - */ -export const zOnnxModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, - model_type: z.literal('onnx'), -}); -/** - * Type alias for Onnx model parameter, inferred from its zod schema - */ -export type OnnxModelParam = z.infer; -/** - * Validates/type-guards a value as a Onnx model parameter - */ -export const isValidOnnxModel = (val: unknown): val is OnnxModelParam => - zOnnxModel.safeParse(val).success; - -export const zMainOrOnnxModel = z.union([zMainModel, zOnnxModel]); - -/** - * Zod schema for VAE parameter - */ -export const zVaeModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type VaeModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidVaeModel = (val: unknown): val is VaeModelParam => - zVaeModel.safeParse(val).success; -/** - * Zod schema for LoRA - */ -export const zLoRAModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type LoRAModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidLoRAModel = (val: unknown): val is LoRAModelParam => - zLoRAModel.safeParse(val).success; -/** - * Zod schema for ControlNet models - */ -export const zControlNetModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type ControlNetModelParam = z.infer; -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidControlNetModel = ( - val: unknown -): val is ControlNetModelParam => zControlNetModel.safeParse(val).success; -/** - * Zod schema for IP-Adapter models - */ -export const zIPAdapterModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type IPAdapterModelParam = z.infer; -/** - * Zod schema for T2I-Adapter models - */ -export const zT2IAdapterModel = z.object({ - model_name: z.string().min(1), - base_model: zBaseModel, -}); -export const isValidT2IAdapterModel = ( - val: unknown -): val is T2IAdapterModelParam => zT2IAdapterModel.safeParse(val).success; - -/** - * Type alias for model parameter, inferred from its zod schema - */ -export type T2IAdapterModelParam = z.infer; -/** - * Zod schema for l2l strength parameter - */ -/** - * Validates/type-guards a value as a model parameter - */ -export const isValidIPAdapterModel = ( - val: unknown -): val is IPAdapterModelParam => zIPAdapterModel.safeParse(val).success; -export const zStrength = z.number().min(0).max(1); -/** - * Type alias for l2l strength parameter, inferred from its zod schema - */ -export type StrengthParam = z.infer; -/** - * Validates/type-guards a value as a l2l strength parameter - */ -export const isValidStrength = (val: unknown): val is StrengthParam => - zStrength.safeParse(val).success; - -/** - * Zod schema for a precision parameter - */ -export const zPrecision = z.enum(['fp16', 'fp32']); -/** - * Type alias for precision parameter, inferred from its zod schema - */ -export type PrecisionParam = z.infer; -/** - * Validates/type-guards a value as a precision parameter - */ -export const isValidPrecision = (val: unknown): val is PrecisionParam => - zPrecision.safeParse(val).success; - -/** - * Zod schema for a high resolution fix method parameter. - */ -export const zHrfMethod = z.enum(['ESRGAN', 'bilinear']); -/** - * Type alias for high resolution fix method parameter, inferred from its zod schema - */ -export type HrfMethodParam = z.infer; -/** - * Validates/type-guards a value as a high resolution fix method parameter - */ -export const isValidHrfMethod = (val: unknown): val is HrfMethodParam => - zHrfMethod.safeParse(val).success; - -/** - * Zod schema for SDXL refiner positive aesthetic score parameter - */ -export const zSDXLRefinerPositiveAestheticScore = z.number().min(1).max(10); -/** - * Type alias for SDXL refiner aesthetic positive score parameter, inferred from its zod schema - */ -export type SDXLRefinerPositiveAestheticScoreParam = z.infer< - typeof zSDXLRefinerPositiveAestheticScore +export type ParameterCanvasCoherenceMode = z.infer< + typeof zParameterCanvasCoherenceMode >; -/** - * Validates/type-guards a value as a SDXL refiner positive aesthetic score parameter - */ -export const isValidSDXLRefinerPositiveAestheticScore = ( +export const isParameterCanvasCoherenceMode = ( val: unknown -): val is SDXLRefinerPositiveAestheticScoreParam => - zSDXLRefinerPositiveAestheticScore.safeParse(val).success; - -/** - * Zod schema for SDXL refiner negative aesthetic score parameter - */ -export const zSDXLRefinerNegativeAestheticScore = z.number().min(1).max(10); -/** - * Type alias for SDXL refiner aesthetic negative score parameter, inferred from its zod schema - */ -export type SDXLRefinerNegativeAestheticScoreParam = z.infer< - typeof zSDXLRefinerNegativeAestheticScore ->; -/** - * Validates/type-guards a value as a SDXL refiner negative aesthetic score parameter - */ -export const isValidSDXLRefinerNegativeAestheticScore = ( - val: unknown -): val is SDXLRefinerNegativeAestheticScoreParam => - zSDXLRefinerNegativeAestheticScore.safeParse(val).success; - -/** - * Zod schema for SDXL start parameter - */ -export const zSDXLRefinerstart = z.number().min(0).max(1); -/** - * Type alias for SDXL start, inferred from its zod schema - */ -export type SDXLRefinerStartParam = z.infer; -/** - * Validates/type-guards a value as a SDXL refiner aesthetic score parameter - */ -export const isValidSDXLRefinerStart = ( - val: unknown -): val is SDXLRefinerStartParam => zSDXLRefinerstart.safeParse(val).success; - -/** - * Zod schema for a mask blur method parameter - */ -export const zMaskBlurMethod = z.enum(['box', 'gaussian']); -/** - * Type alias for mask blur method parameter, inferred from its zod schema - */ -export type MaskBlurMethodParam = z.infer; -/** - * Validates/type-guards a value as a mask blur method parameter - */ -export const isValidMaskBlurMethod = ( - val: unknown -): val is MaskBlurMethodParam => zMaskBlurMethod.safeParse(val).success; - -/** - * Zod schema for a Canvas Coherence Mode method parameter - */ -export const zCanvasCoherenceMode = z.enum(['unmasked', 'mask', 'edge']); -/** - * Type alias for Canvas Coherence Mode parameter, inferred from its zod schema - */ -export type CanvasCoherenceModeParam = z.infer; -/** - * Validates/type-guards a value as a mask blur method parameter - */ -export const isValidCoherenceModeParam = ( - val: unknown -): val is CanvasCoherenceModeParam => - zCanvasCoherenceMode.safeParse(val).success; - -/** - * Zod schema for a boolean. - */ -export const zBoolean = z.boolean(); - -/** - * Validates/type-guards a value as a boolean parameter - */ -export const isValidBoolean = (val: unknown): val is boolean => - zBoolean.safeParse(val).success && val !== null && val !== undefined; - -// /** -// * Zod schema for BaseModelType -// */ -// export const zBaseModelType = z.enum(['sd-1', 'sd-2']); -// /** -// * Type alias for base model type, inferred from its zod schema. Should be identical to the type alias from OpenAPI. -// */ -// export type BaseModelType = z.infer; -// /** -// * Validates/type-guards a value as a base model type -// */ -// export const isValidBaseModelType = (val: unknown): val is BaseModelType => -// zBaseModelType.safeParse(val).success; +): val is ParameterCanvasCoherenceMode => + zParameterCanvasCoherenceMode.safeParse(val).success; +// #endregion diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts index 30e6fdcd3d..d823edbce2 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToControlNetModelParam.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zControlNetModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterControlNetModel } from 'features/parameters/types/parameterSchemas'; import { ControlNetModelField } from 'services/api/types'; export const modelIdToControlNetModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToControlNetModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = controlNetModelId.split('/'); - const result = zControlNetModel.safeParse({ + const result = zParameterControlNetModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts index 4d58046545..f3ccce47df 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToIPAdapterModelParams.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterIPAdapterModel } from 'features/parameters/types/parameterSchemas'; import { IPAdapterModelField } from 'services/api/types'; export const modelIdToIPAdapterModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToIPAdapterModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = ipAdapterModelId.split('/'); - const result = zIPAdapterModel.safeParse({ + const result = zParameterIPAdapterModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts index bf4c6454fb..abe0e9e58f 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToLoRAModelParam.ts @@ -1,14 +1,17 @@ import { logger } from 'app/logging/logger'; -import { LoRAModelParam, zLoRAModel } from '../types/parameterSchemas'; +import { + ParameterLoRAModel, + zParameterLoRAModel, +} from '../types/parameterSchemas'; export const modelIdToLoRAModelParam = ( loraModelId: string -): LoRAModelParam | undefined => { +): ParameterLoRAModel | undefined => { const log = logger('models'); const [base_model, _model_type, model_name] = loraModelId.split('/'); - const result = zLoRAModel.safeParse({ + const result = zParameterLoRAModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts index 78a3bcc515..9500546f84 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToMainModelParam.ts @@ -1,17 +1,16 @@ import { logger } from 'app/logging/logger'; import { - MainModelParam, - OnnxModelParam, - zMainOrOnnxModel, + ParameterModel, + zParameterModel, } from 'features/parameters/types/parameterSchemas'; export const modelIdToMainModelParam = ( mainModelId: string -): OnnxModelParam | MainModelParam | undefined => { +): ParameterModel | undefined => { const log = logger('models'); const [base_model, model_type, model_name] = mainModelId.split('/'); - const result = zMainOrOnnxModel.safeParse({ + const result = zParameterModel.safeParse({ base_model, model_name, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts index 780ac56459..5ed185ef8e 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToSDXLRefinerModelParam.ts @@ -1,16 +1,16 @@ import { logger } from 'app/logging/logger'; import { - SDXLRefinerModelParam, - zSDXLRefinerModel, + ParameterSDXLRefinerModel, + zParameterSDXLRefinerModel, } from 'features/parameters/types/parameterSchemas'; export const modelIdToSDXLRefinerModelParam = ( mainModelId: string -): SDXLRefinerModelParam | undefined => { +): ParameterSDXLRefinerModel | undefined => { const log = logger('models'); const [base_model, model_type, model_name] = mainModelId.split('/'); - const result = zSDXLRefinerModel.safeParse({ + const result = zParameterSDXLRefinerModel.safeParse({ base_model, model_name, model_type, diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts index 95f1a3f25a..3d66ef66e8 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToT2IAdapterModelParam.ts @@ -1,5 +1,5 @@ import { logger } from 'app/logging/logger'; -import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas'; +import { zParameterT2IAdapterModel } from 'features/parameters/types/parameterSchemas'; import { T2IAdapterModelField } from 'services/api/types'; export const modelIdToT2IAdapterModelParam = ( @@ -8,7 +8,7 @@ export const modelIdToT2IAdapterModelParam = ( const log = logger('models'); const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/'); - const result = zT2IAdapterModel.safeParse({ + const result = zParameterT2IAdapterModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts index 1f3908dd47..a30dbcc12f 100644 --- a/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts +++ b/invokeai/frontend/web/src/features/parameters/util/modelIdToVAEModelParam.ts @@ -1,13 +1,16 @@ import { logger } from 'app/logging/logger'; -import { VaeModelParam, zVaeModel } from '../types/parameterSchemas'; +import { + ParameterVAEModel, + zParameterVAEModel, +} from '../types/parameterSchemas'; export const modelIdToVAEModelParam = ( vaeModelId: string -): VaeModelParam | undefined => { +): ParameterVAEModel | undefined => { const log = logger('models'); const [base_model, _model_type, model_name] = vaeModelId.split('/'); - const result = zVaeModel.safeParse({ + const result = zParameterVAEModel.safeParse({ base_model, model_name, }); diff --git a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx index 50400aef9f..90a3e6eeed 100644 --- a/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx +++ b/invokeai/frontend/web/src/features/sdxl/components/SDXLRefiner/ParamSDXLRefinerScheduler.tsx @@ -3,10 +3,8 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { setRefinerScheduler } from 'features/sdxl/store/sdxlSlice'; import { map } from 'lodash-es'; import { memo, useCallback } from 'react'; @@ -22,7 +20,7 @@ const selector = createSelector( const data = map(SCHEDULER_LABEL_MAP, (label, name) => ({ value: name, label: label, - group: enabledSchedulers.includes(name as SchedulerParam) + group: enabledSchedulers.includes(name as ParameterScheduler) ? 'Favorites' : undefined, })).sort((a, b) => a.label.localeCompare(b.label)); @@ -45,7 +43,7 @@ const ParamSDXLRefinerScheduler = () => { if (!v) { return; } - dispatch(setRefinerScheduler(v as SchedulerParam)); + dispatch(setRefinerScheduler(v as ParameterScheduler)); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts index 73f5779a52..861fa91e23 100644 --- a/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts +++ b/invokeai/frontend/web/src/features/sdxl/store/sdxlSlice.ts @@ -1,21 +1,21 @@ import { PayloadAction, createSlice } from '@reduxjs/toolkit'; import { - NegativeStylePromptSDXLParam, - PositiveStylePromptSDXLParam, - SDXLRefinerModelParam, - SchedulerParam, + ParameterNegativeStylePromptSDXL, + ParameterPositiveStylePromptSDXL, + ParameterSDXLRefinerModel, + ParameterScheduler, } from 'features/parameters/types/parameterSchemas'; type SDXLState = { - positiveStylePrompt: PositiveStylePromptSDXLParam; - negativeStylePrompt: NegativeStylePromptSDXLParam; + positiveStylePrompt: ParameterPositiveStylePromptSDXL; + negativeStylePrompt: ParameterNegativeStylePromptSDXL; shouldConcatSDXLStylePrompt: boolean; shouldUseSDXLRefiner: boolean; sdxlImg2ImgDenoisingStrength: number; - refinerModel: SDXLRefinerModelParam | null; + refinerModel: ParameterSDXLRefinerModel | null; refinerSteps: number; refinerCFGScale: number; - refinerScheduler: SchedulerParam; + refinerScheduler: ParameterScheduler; refinerPositiveAestheticScore: number; refinerNegativeAestheticScore: number; refinerStart: number; @@ -57,7 +57,7 @@ const sdxlSlice = createSlice({ }, refinerModelChanged: ( state, - action: PayloadAction + action: PayloadAction ) => { state.refinerModel = action.payload; }, @@ -67,7 +67,7 @@ const sdxlSlice = createSlice({ setRefinerCFGScale: (state, action: PayloadAction) => { state.refinerCFGScale = action.payload; }, - setRefinerScheduler: (state, action: PayloadAction) => { + setRefinerScheduler: (state, action: PayloadAction) => { state.refinerScheduler = action.payload; }, setRefinerPositiveAestheticScore: ( diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx index 0be58a8815..270d9aed2c 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsSchedulers.tsx @@ -1,10 +1,8 @@ import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIMantineMultiSelect from 'common/components/IAIMantineMultiSelect'; -import { - SCHEDULER_LABEL_MAP, - SchedulerParam, -} from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; +import { SCHEDULER_LABEL_MAP } from 'features/parameters/types/constants'; import { favoriteSchedulersChanged } from 'features/ui/store/uiSlice'; import { map } from 'lodash-es'; import { useCallback } from 'react'; @@ -26,7 +24,7 @@ export default function SettingsSchedulers() { const handleChange = useCallback( (v: string[]) => { - dispatch(favoriteSchedulersChanged(v as SchedulerParam[])); + dispatch(favoriteSchedulersChanged(v as ParameterScheduler[])); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx index 2668062da6..46ea9fb051 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/ImageToImage/ImageToImageTabCoreParameters.tsx @@ -9,7 +9,7 @@ import ParamSteps from 'features/parameters/components/Parameters/Core/ParamStep import ImageToImageFit from 'features/parameters/components/Parameters/ImageToImage/ImageToImageFit'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx index 3f3cf2db05..29ab63cb1c 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/TextToImage/TextToImageTabCoreParameters.tsx @@ -7,7 +7,7 @@ import ParamModelandVAEandScheduler from 'features/parameters/components/Paramet import ParamSize from 'features/parameters/components/Parameters/Core/ParamSize'; import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx index 40a5026d09..bc86386515 100644 --- a/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx +++ b/invokeai/frontend/web/src/features/ui/components/tabs/UnifiedCanvas/UnifiedCanvasCoreParameters.tsx @@ -8,7 +8,7 @@ import ParamModelandVAEandScheduler from 'features/parameters/components/Paramet import ParamSteps from 'features/parameters/components/Parameters/Core/ParamSteps'; import ImageToImageStrength from 'features/parameters/components/Parameters/ImageToImage/ImageToImageStrength'; import ParamSeedFull from 'features/parameters/components/Parameters/Seed/ParamSeedFull'; -import { useCoreParametersCollapseLabel } from 'features/parameters/util/useCoreParametersCollapseLabel'; +import { useCoreParametersCollapseLabel } from 'features/parameters/hooks/useCoreParametersCollapseLabel'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; diff --git a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts index 9782d0bfac..69cfe42827 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiSlice.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiSlice.ts @@ -1,7 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice } from '@reduxjs/toolkit'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; -import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { InvokeTabName } from './tabMap'; import { UIState } from './uiTypes'; @@ -50,7 +50,7 @@ export const uiSlice = createSlice({ }, favoriteSchedulersChanged: ( state, - action: PayloadAction + action: PayloadAction ) => { state.favoriteSchedulers = action.payload; }, diff --git a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts index 1b9fee6989..b532043054 100644 --- a/invokeai/frontend/web/src/features/ui/store/uiTypes.ts +++ b/invokeai/frontend/web/src/features/ui/store/uiTypes.ts @@ -1,4 +1,4 @@ -import { SchedulerParam } from 'features/parameters/types/parameterSchemas'; +import { ParameterScheduler } from 'features/parameters/types/parameterSchemas'; import { InvokeTabName } from './tabMap'; export type Coordinates = { @@ -23,7 +23,7 @@ export interface UIState { shouldShowProgressInViewer: boolean; shouldShowEmbeddingPicker: boolean; shouldAutoChangeDimensions: boolean; - favoriteSchedulers: SchedulerParam[]; + favoriteSchedulers: ParameterScheduler[]; globalContextMenuCloseTrigger: number; panels: Record; } diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 166d00a3db..97473b21f2 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -7,7 +7,7 @@ import { IMAGE_CATEGORIES, IMAGE_LIMIT, } from 'features/gallery/store/types'; -import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/types'; +import { CoreMetadata, zCoreMetadata } from 'features/nodes/types/metadata'; import { keyBy } from 'lodash-es'; import { ApiTagDescription, LIST_TAG, api } from '..'; import { components, paths } from '../schema'; diff --git a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts index 1792788d57..b7611cb397 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/workflows.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/workflows.ts @@ -1,11 +1,11 @@ import { logger } from 'app/logging/logger'; -import { Workflow, zWorkflow } from 'features/nodes/types/types'; +import { WorkflowV2, zWorkflowV2 } from 'features/nodes/types/workflow'; import { api } from '..'; import { paths } from '../schema'; export const workflowsApi = api.injectEndpoints({ endpoints: (build) => ({ - getWorkflow: build.query({ + getWorkflow: build.query({ query: (workflow_id) => `workflows/i/${workflow_id}`, providesTags: (result, error, workflow_id) => [ { type: 'Workflow', id: workflow_id }, @@ -14,7 +14,7 @@ export const workflowsApi = api.injectEndpoints({ response: paths['/api/v1/workflows/i/{workflow_id}']['get']['responses']['200']['content']['application/json'] ) => { if (response) { - const result = zWorkflow.safeParse(response); + const result = zWorkflowV2.safeParse(response); if (result.success) { return result.data; } else { diff --git a/invokeai/frontend/web/src/services/api/guards.ts b/invokeai/frontend/web/src/services/api/guards.ts deleted file mode 100644 index 2893d88e07..0000000000 --- a/invokeai/frontend/web/src/services/api/guards.ts +++ /dev/null @@ -1,67 +0,0 @@ -import { get, isObject, isString } from 'lodash-es'; -import { - GraphExecutionState, - GraphInvocationOutput, - ImageOutput, - IterateInvocationOutput, - CollectInvocationOutput, - ImageField, - LatentsOutput, - ResourceOrigin, - ImageDTO, - BoardDTO, -} from 'services/api/types'; - -export const isImageDTO = (obj: unknown): obj is ImageDTO => { - return ( - isObject(obj) && - 'image_name' in obj && - isString(obj?.image_name) && - 'thumbnail_url' in obj && - isString(obj?.thumbnail_url) && - 'image_url' in obj && - isString(obj?.image_url) && - 'image_origin' in obj && - isString(obj?.image_origin) && - 'created_at' in obj && - isString(obj?.created_at) - ); -}; - -export const isBoardDTO = (obj: unknown): obj is BoardDTO => { - return ( - isObject(obj) && - 'board_id' in obj && - isString(obj?.board_id) && - 'board_name' in obj && - isString(obj?.board_name) - ); -}; - -export const isImageOutput = ( - output: GraphExecutionState['results'][string] -): output is ImageOutput => output?.type === 'image_output'; - -export const isLatentsOutput = ( - output: GraphExecutionState['results'][string] -): output is LatentsOutput => output?.type === 'latents_output'; - -export const isGraphOutput = ( - output: GraphExecutionState['results'][string] -): output is GraphInvocationOutput => output?.type === 'graph_output'; - -export const isIterateOutput = ( - output: GraphExecutionState['results'][string] -): output is IterateInvocationOutput => output?.type === 'iterate_output'; - -export const isCollectOutput = ( - output: GraphExecutionState['results'][string] -): output is CollectInvocationOutput => output?.type === 'collect_output'; - -export const isResourceOrigin = (t: unknown): t is ResourceOrigin => - isString(t) && ['internal', 'external'].includes(t); - -export const isImageField = (imageField: unknown): imageField is ImageField => - isObject(imageField) && - isString(get(imageField, 'image_name')) && - isResourceOrigin(get(imageField, 'image_origin')); diff --git a/invokeai/frontend/web/src/services/api/schema.d.ts b/invokeai/frontend/web/src/services/api/schema.d.ts index c82a195028..5d5b382c4c 100644 --- a/invokeai/frontend/web/src/services/api/schema.d.ts +++ b/invokeai/frontend/web/src/services/api/schema.d.ts @@ -930,6 +930,7 @@ export type components = { /** * Collection * @description The collection of boolean values + * @default [] */ collection?: boolean[]; /** @@ -1310,6 +1311,7 @@ export type components = { /** * Collection * @description The collection, will be provided on execution + * @default [] */ collection?: unknown[]; /** @@ -1581,6 +1583,7 @@ export type components = { /** * Collection * @description The collection of conditioning tensors + * @default [] */ collection?: components["schemas"]["ConditioningField"][]; /** @@ -2893,6 +2896,7 @@ export type components = { /** * Collection * @description The collection of float values + * @default [] */ collection?: number[]; /** @@ -3216,7 +3220,7 @@ export type components = { * @description The nodes in this graph */ nodes?: { - [key: string]: components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["StepParamEasingInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"]; + [key: string]: components["schemas"]["ScaleLatentsInvocation"] | components["schemas"]["SubtractInvocation"] | components["schemas"]["BooleanInvocation"] | components["schemas"]["BooleanCollectionInvocation"] | components["schemas"]["RandomFloatInvocation"] | components["schemas"]["ImageHueAdjustmentInvocation"] | components["schemas"]["RandomIntInvocation"] | components["schemas"]["NoiseInvocation"] | components["schemas"]["FaceMaskInvocation"] | components["schemas"]["FaceIdentifierInvocation"] | components["schemas"]["NormalbaeImageProcessorInvocation"] | components["schemas"]["ImageConvertInvocation"] | components["schemas"]["ConditioningInvocation"] | components["schemas"]["HedImageProcessorInvocation"] | components["schemas"]["LatentsInvocation"] | components["schemas"]["ZoeDepthImageProcessorInvocation"] | components["schemas"]["CompelInvocation"] | components["schemas"]["RandomRangeInvocation"] | components["schemas"]["CreateDenoiseMaskInvocation"] | components["schemas"]["OpenposeImageProcessorInvocation"] | components["schemas"]["StringInvocation"] | components["schemas"]["ColorInvocation"] | components["schemas"]["ContentShuffleImageProcessorInvocation"] | components["schemas"]["MaskFromAlphaInvocation"] | components["schemas"]["MetadataInvocation"] | components["schemas"]["LatentsToImageInvocation"] | components["schemas"]["StringReplaceInvocation"] | components["schemas"]["IntegerCollectionInvocation"] | components["schemas"]["ImageScaleInvocation"] | components["schemas"]["LaMaInfillInvocation"] | components["schemas"]["ImageBlurInvocation"] | components["schemas"]["RoundInvocation"] | components["schemas"]["ControlNetInvocation"] | components["schemas"]["ImageToLatentsInvocation"] | components["schemas"]["BlendLatentsInvocation"] | components["schemas"]["SDXLLoraLoaderInvocation"] | components["schemas"]["FloatToIntegerInvocation"] | components["schemas"]["CollectInvocation"] | components["schemas"]["FreeUInvocation"] | components["schemas"]["StringCollectionInvocation"] | components["schemas"]["LineartImageProcessorInvocation"] | components["schemas"]["FloatMathInvocation"] | components["schemas"]["ImageInvocation"] | components["schemas"]["ImageLerpInvocation"] | components["schemas"]["InfillTileInvocation"] | components["schemas"]["ImageInverseLerpInvocation"] | components["schemas"]["LineartAnimeImageProcessorInvocation"] | components["schemas"]["CoreMetadataInvocation"] | components["schemas"]["SeamlessModeInvocation"] | components["schemas"]["SDXLModelLoaderInvocation"] | components["schemas"]["OnnxModelLoaderInvocation"] | components["schemas"]["ColorCorrectInvocation"] | components["schemas"]["SaveImageInvocation"] | components["schemas"]["StringSplitInvocation"] | components["schemas"]["DivideInvocation"] | components["schemas"]["ClipSkipInvocation"] | components["schemas"]["MetadataItemInvocation"] | components["schemas"]["VaeLoaderInvocation"] | components["schemas"]["LoraLoaderInvocation"] | components["schemas"]["MediapipeFaceProcessorInvocation"] | components["schemas"]["ImageMultiplyInvocation"] | components["schemas"]["AddInvocation"] | components["schemas"]["LinearUIOutputInvocation"] | components["schemas"]["ONNXLatentsToImageInvocation"] | components["schemas"]["RangeInvocation"] | components["schemas"]["LatentsCollectionInvocation"] | components["schemas"]["ImageWatermarkInvocation"] | components["schemas"]["ESRGANInvocation"] | components["schemas"]["ColorMapImageProcessorInvocation"] | components["schemas"]["BlankImageInvocation"] | components["schemas"]["ImageNSFWBlurInvocation"] | components["schemas"]["SchedulerInvocation"] | components["schemas"]["GraphInvocation"] | components["schemas"]["PidiImageProcessorInvocation"] | components["schemas"]["ONNXPromptInvocation"] | components["schemas"]["LeresImageProcessorInvocation"] | components["schemas"]["FloatCollectionInvocation"] | components["schemas"]["T2IAdapterInvocation"] | components["schemas"]["IPAdapterInvocation"] | components["schemas"]["MultiplyInvocation"] | components["schemas"]["InfillPatchMatchInvocation"] | components["schemas"]["MainModelLoaderInvocation"] | components["schemas"]["ImagePasteInvocation"] | components["schemas"]["ImageCropInvocation"] | components["schemas"]["MaskEdgeInvocation"] | components["schemas"]["IntegerMathInvocation"] | components["schemas"]["DynamicPromptInvocation"] | components["schemas"]["FaceOffInvocation"] | components["schemas"]["SDXLRefinerCompelPromptInvocation"] | components["schemas"]["ResizeLatentsInvocation"] | components["schemas"]["SegmentAnythingProcessorInvocation"] | components["schemas"]["InfillColorInvocation"] | components["schemas"]["ImageCollectionInvocation"] | components["schemas"]["FloatLinearRangeInvocation"] | components["schemas"]["TileResamplerProcessorInvocation"] | components["schemas"]["CannyImageProcessorInvocation"] | components["schemas"]["MlsdImageProcessorInvocation"] | components["schemas"]["RangeOfSizeInvocation"] | components["schemas"]["SDXLRefinerModelLoaderInvocation"] | components["schemas"]["MaskCombineInvocation"] | components["schemas"]["ImageChannelOffsetInvocation"] | components["schemas"]["IterateInvocation"] | components["schemas"]["IntegerInvocation"] | components["schemas"]["CvInpaintInvocation"] | components["schemas"]["ImageChannelMultiplyInvocation"] | components["schemas"]["ONNXTextToLatentsInvocation"] | components["schemas"]["ImageResizeInvocation"] | components["schemas"]["FloatInvocation"] | components["schemas"]["SDXLCompelPromptInvocation"] | components["schemas"]["DenoiseLatentsInvocation"] | components["schemas"]["ConditioningCollectionInvocation"] | components["schemas"]["CV2InfillInvocation"] | components["schemas"]["MergeMetadataInvocation"] | components["schemas"]["ShowImageInvocation"] | components["schemas"]["StringJoinThreeInvocation"] | components["schemas"]["StringSplitNegInvocation"] | components["schemas"]["ImageChannelInvocation"] | components["schemas"]["StringJoinInvocation"] | components["schemas"]["MidasDepthImageProcessorInvocation"] | components["schemas"]["PromptsFromFileInvocation"] | components["schemas"]["StepParamEasingInvocation"]; }; /** * Edges @@ -3253,7 +3257,7 @@ export type components = { * @description The results of node executions */ results: { - [key: string]: components["schemas"]["String2Output"] | components["schemas"]["UNetOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["SchedulerOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SeamlessModeOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["SDXLLoraLoaderOutput"]; + [key: string]: components["schemas"]["SeamlessModeOutput"] | components["schemas"]["CollectInvocationOutput"] | components["schemas"]["ImageCollectionOutput"] | components["schemas"]["FloatCollectionOutput"] | components["schemas"]["MetadataItemOutput"] | components["schemas"]["LatentsOutput"] | components["schemas"]["GraphInvocationOutput"] | components["schemas"]["BooleanOutput"] | components["schemas"]["SDXLLoraLoaderOutput"] | components["schemas"]["IPAdapterOutput"] | components["schemas"]["ControlOutput"] | components["schemas"]["T2IAdapterOutput"] | components["schemas"]["ONNXModelLoaderOutput"] | components["schemas"]["String2Output"] | components["schemas"]["IterateInvocationOutput"] | components["schemas"]["SDXLModelLoaderOutput"] | components["schemas"]["FloatOutput"] | components["schemas"]["VAEOutput"] | components["schemas"]["ColorOutput"] | components["schemas"]["BooleanCollectionOutput"] | components["schemas"]["ConditioningCollectionOutput"] | components["schemas"]["FaceMaskOutput"] | components["schemas"]["LatentsCollectionOutput"] | components["schemas"]["StringOutput"] | components["schemas"]["StringCollectionOutput"] | components["schemas"]["IntegerOutput"] | components["schemas"]["UNetOutput"] | components["schemas"]["DenoiseMaskOutput"] | components["schemas"]["ConditioningOutput"] | components["schemas"]["FaceOffOutput"] | components["schemas"]["IntegerCollectionOutput"] | components["schemas"]["ImageOutput"] | components["schemas"]["StringPosNegOutput"] | components["schemas"]["ModelLoaderOutput"] | components["schemas"]["ColorCollectionOutput"] | components["schemas"]["LoraLoaderOutput"] | components["schemas"]["SDXLRefinerModelLoaderOutput"] | components["schemas"]["CLIPOutput"] | components["schemas"]["MetadataOutput"] | components["schemas"]["NoiseOutput"] | components["schemas"]["ClipSkipInvocationOutput"] | components["schemas"]["SchedulerOutput"]; }; /** * Errors @@ -4727,6 +4731,7 @@ export type components = { /** * Seed * @description The seed to use for tile generation (omit for random) + * @default 0 */ seed?: number; /** @@ -4761,6 +4766,7 @@ export type components = { /** * Collection * @description The collection of integer values + * @default [] */ collection?: number[]; /** @@ -4940,6 +4946,7 @@ export type components = { /** * Collection * @description The list of items to iterate over + * @default [] */ collection?: unknown[]; /** @@ -6342,6 +6349,7 @@ export type components = { /** * Seed * @description Seed for random number generation + * @default 0 */ seed?: number; /** @@ -7194,6 +7202,7 @@ export type components = { /** * Seed * @description The seed for the RNG (omit for random) + * @default 0 */ seed?: number; /** @@ -8532,6 +8541,7 @@ export type components = { /** * Collection * @description The collection of string values + * @default [] */ collection?: string[]; /** @@ -9529,6 +9539,24 @@ export type components = { * @enum {string} */ invokeai__backend__model_manager__config__SchedulerPredictionType: "epsilon" | "v_prediction" | "sample"; + /** + * FieldKind + * @description The kind of field. + * - `Input`: An input field on a node. + * - `Output`: An output field on a node. + * - `Internal`: A field which is treated as an input, but cannot be used in node definitions. Metadata is + * one example. It is provided to nodes via the WithMetadata class, and we want to reserve the field name + * "metadata" for this on all nodes. `FieldKind` is used to short-circuit the field name validation logic, + * allowing "metadata" for that field. + * - `NodeAttribute`: The field is a node attribute. These are fields which are not inputs or outputs, + * but which are used to store information about the node. For example, the `id` and `type` fields are node + * attributes. + * + * The presence of this in `json_schema_extra["field_kind"]` is used when initializing node schemas on app + * startup, and when generating the OpenAPI schema for the workflow editor. + * @enum {string} + */ + FieldKind: "input" | "output" | "internal" | "node_attribute"; /** * Input * @description The type of input a field accepts. @@ -9538,9 +9566,65 @@ export type components = { * @enum {string} */ Input: "connection" | "direct" | "any"; + /** + * InputFieldJSONSchemaExtra + * @description Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + * and by the workflow editor during schema parsing and UI rendering. + */ + InputFieldJSONSchemaExtra: { + input: components["schemas"]["Input"]; + /** Orig Required */ + orig_required: boolean; + field_kind: components["schemas"]["FieldKind"]; + /** + * Default + * @default null + */ + default: unknown; + /** + * Orig Default + * @default null + */ + orig_default: unknown; + /** + * Ui Hidden + * @default false + */ + ui_hidden: boolean; + /** @default null */ + ui_type: components["schemas"]["UIType"] | null; + /** @default null */ + ui_component: components["schemas"]["UIComponent"] | null; + /** + * Ui Order + * @default null + */ + ui_order: number | null; + /** + * Ui Choice Labels + * @default null + */ + ui_choice_labels: { + [key: string]: string; + } | null; + }; + /** + * OutputFieldJSONSchemaExtra + * @description Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + * during schema parsing and UI rendering. + */ + OutputFieldJSONSchemaExtra: { + field_kind: components["schemas"]["FieldKind"]; + /** Ui Hidden */ + ui_hidden: boolean; + ui_type: components["schemas"]["UIType"] | null; + /** Ui Order */ + ui_order: number | null; + }; /** * UIComponent - * @description The type of UI component to use for a field, used to override the default components, which are inferred from the field type. + * @description The type of UI component to use for a field, used to override the default components, which are + * inferred from the field type. * @enum {string} */ UIComponent: "none" | "textarea" | "slider"; @@ -9570,89 +9654,73 @@ export type components = { /** * Version * @description The node's version. Should be a valid semver string e.g. "1.0.0" or "3.8.13". - * @default null */ - version: string | null; + version: string; + /** + * Is Custom + * @description Whether or not this is a custom node + * @default false + */ + is_custom: boolean; }; /** * UIType - * @description Type hints for the UI. - * If a field should be provided a data type that does not exactly match the python type of the field, use this to provide the type that should be used instead. See the node development docs for detail on adding a new field type, which involves client-side changes. + * @description Type hints for the UI for situations in which the field type is not enough to infer the correct UI type. + * + * - Model Fields + * The most common node-author-facing use will be for model fields. Internally, there is no difference + * between SD-1, SD-2 and SDXL model fields - they all use the class `MainModelField`. To ensure the + * base-model-specific UI is rendered, use e.g. `ui_type=UIType.SDXLMainModelField` to indicate that + * the field is an SDXL main model field. + * + * - Any Field + * We cannot infer the usage of `typing.Any` via schema parsing, so you *must* use `ui_type=UIType.Any` to + * indicate that the field accepts any type. Use with caution. This cannot be used on outputs. + * + * - Scheduler Field + * Special handling in the UI is needed for this field, which otherwise would be parsed as a plain enum field. + * + * - Internal Fields + * Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate + * handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These + * should not be used by node authors. + * + * - DEPRECATED Fields + * These types are deprecated and should not be used by node authors. A warning will be logged if one is + * used, and the type will be ignored. They are included here for backwards compatibility. * @enum {string} */ - UIType: "boolean" | "ColorField" | "ConditioningField" | "ControlField" | "float" | "ImageField" | "integer" | "LatentsField" | "string" | "BooleanCollection" | "ColorCollection" | "ConditioningCollection" | "ControlCollection" | "FloatCollection" | "ImageCollection" | "IntegerCollection" | "LatentsCollection" | "StringCollection" | "BooleanPolymorphic" | "ColorPolymorphic" | "ConditioningPolymorphic" | "ControlPolymorphic" | "FloatPolymorphic" | "ImagePolymorphic" | "IntegerPolymorphic" | "LatentsPolymorphic" | "StringPolymorphic" | "MainModelField" | "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VaeModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "UNetField" | "VaeField" | "ClipField" | "Collection" | "CollectionItem" | "enum" | "Scheduler" | "WorkflowField" | "IsIntermediate" | "BoardField" | "Any" | "MetadataItem" | "MetadataItemCollection" | "MetadataItemPolymorphic" | "MetadataDict"; - /** - * _InputField - * @description *DO NOT USE* - * This helper class is used to tell the client about our custom field attributes via OpenAPI - * schema generation, and Typescript type generation from that schema. It serves no functional - * purpose in the backend. - */ - _InputField: { - input: components["schemas"]["Input"]; - /** Ui Hidden */ - ui_hidden: boolean; - ui_type: components["schemas"]["UIType"] | null; - ui_component: components["schemas"]["UIComponent"] | null; - /** Ui Order */ - ui_order: number | null; - /** Ui Choice Labels */ - ui_choice_labels: { - [key: string]: string; - } | null; - /** Item Default */ - item_default: unknown; - }; - /** - * _OutputField - * @description *DO NOT USE* - * This helper class is used to tell the client about our custom field attributes via OpenAPI - * schema generation, and Typescript type generation from that schema. It serves no functional - * purpose in the backend. - */ - _OutputField: { - /** Ui Hidden */ - ui_hidden: boolean; - ui_type: components["schemas"]["UIType"] | null; - /** Ui Order */ - ui_order: number | null; - }; - /** - * IPAdapterModelFormat - * @description An enumeration. - * @enum {string} - */ - IPAdapterModelFormat: "invokeai"; - /** - * StableDiffusionXLModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; - /** - * ControlNetModelFormat - * @description An enumeration. - * @enum {string} - */ - ControlNetModelFormat: "checkpoint" | "diffusers"; - /** - * StableDiffusion2ModelFormat - * @description An enumeration. - * @enum {string} - */ - StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; + UIType: "SDXLMainModelField" | "SDXLRefinerModelField" | "ONNXModelField" | "VAEModelField" | "LoRAModelField" | "ControlNetModelField" | "IPAdapterModelField" | "SchedulerField" | "AnyField" | "CollectionField" | "CollectionItemField" | "DEPRECATED_Boolean" | "DEPRECATED_Color" | "DEPRECATED_Conditioning" | "DEPRECATED_Control" | "DEPRECATED_Float" | "DEPRECATED_Image" | "DEPRECATED_Integer" | "DEPRECATED_Latents" | "DEPRECATED_String" | "DEPRECATED_BooleanCollection" | "DEPRECATED_ColorCollection" | "DEPRECATED_ConditioningCollection" | "DEPRECATED_ControlCollection" | "DEPRECATED_FloatCollection" | "DEPRECATED_ImageCollection" | "DEPRECATED_IntegerCollection" | "DEPRECATED_LatentsCollection" | "DEPRECATED_StringCollection" | "DEPRECATED_BooleanPolymorphic" | "DEPRECATED_ColorPolymorphic" | "DEPRECATED_ConditioningPolymorphic" | "DEPRECATED_ControlPolymorphic" | "DEPRECATED_FloatPolymorphic" | "DEPRECATED_ImagePolymorphic" | "DEPRECATED_IntegerPolymorphic" | "DEPRECATED_LatentsPolymorphic" | "DEPRECATED_StringPolymorphic" | "DEPRECATED_MainModel" | "DEPRECATED_UNet" | "DEPRECATED_Vae" | "DEPRECATED_CLIP" | "DEPRECATED_Collection" | "DEPRECATED_CollectionItem" | "DEPRECATED_Enum" | "DEPRECATED_WorkflowField" | "DEPRECATED_IsIntermediate" | "DEPRECATED_BoardField" | "DEPRECATED_MetadataItem" | "DEPRECATED_MetadataItemCollection" | "DEPRECATED_MetadataItemPolymorphic" | "DEPRECATED_MetadataDict"; /** * T2IAdapterModelFormat * @description An enumeration. * @enum {string} */ T2IAdapterModelFormat: "diffusers"; + /** + * StableDiffusionXLModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusionXLModelFormat: "checkpoint" | "diffusers"; + /** + * StableDiffusion2ModelFormat + * @description An enumeration. + * @enum {string} + */ + StableDiffusion2ModelFormat: "checkpoint" | "diffusers"; /** * StableDiffusionOnnxModelFormat * @description An enumeration. * @enum {string} */ StableDiffusionOnnxModelFormat: "olive" | "onnx"; + /** + * ControlNetModelFormat + * @description An enumeration. + * @enum {string} + */ + ControlNetModelFormat: "checkpoint" | "diffusers"; /** * CLIPVisionModelFormat * @description An enumeration. @@ -9665,6 +9733,12 @@ export type components = { * @enum {string} */ StableDiffusion1ModelFormat: "checkpoint" | "diffusers"; + /** + * IPAdapterModelFormat + * @description An enumeration. + * @enum {string} + */ + IPAdapterModelFormat: "invokeai"; }; responses: never; parameters: never; diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index ce3e75a584..3c5e54536e 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -1,6 +1,7 @@ import { UseToastOptions } from '@chakra-ui/react'; import { EntityState } from '@reduxjs/toolkit'; import { components, paths } from './schema'; +import { O } from 'ts-toolbelt'; type s = components['schemas']; @@ -27,8 +28,8 @@ export type BatchConfig = export type EnqueueBatchResult = components['schemas']['EnqueueBatchResult']; -export type _InputField = s['_InputField']; -export type _OutputField = s['_OutputField']; +export type InputFieldJSONSchemaExtra = s['InputFieldJSONSchemaExtra']; +export type OutputFieldJSONSchemaExtra = s['OutputFieldJSONSchemaExtra']; // App Info export type AppVersion = s['AppVersion']; @@ -57,6 +58,7 @@ export type MainModelField = s['MainModelField']; export type OnnxModelField = s['OnnxModelField']; export type VAEModelField = s['VAEModelField']; export type LoRAModelField = s['LoRAModelField']; +export type LoRAModelFormat = s['LoRAModelFormat']; export type ControlNetModelField = s['ControlNetModelField']; export type IPAdapterModelField = s['IPAdapterModelField']; export type T2IAdapterModelField = s['T2IAdapterModelField']; @@ -105,6 +107,7 @@ export type ImportModelConfig = s['Body_import_model']; // Graphs export type Graph = s['Graph']; +export type NonNullableGraph = O.Required; export type Edge = s['Edge']; export type GraphExecutionState = s['GraphExecutionState']; export type Batch = s['Batch']; diff --git a/invokeai/frontend/web/src/services/events/types.ts b/invokeai/frontend/web/src/services/events/types.ts index 543107bb13..b1d7f14731 100644 --- a/invokeai/frontend/web/src/services/events/types.ts +++ b/invokeai/frontend/web/src/services/events/types.ts @@ -1,5 +1,4 @@ import { components } from 'services/api/schema'; -import { O } from 'ts-toolbelt'; import { BaseModelType, Graph, @@ -17,11 +16,6 @@ export type ProgressImage = { height: number; }; -export type AnyInvocationType = O.Required< - NonNullable[string]>, - 'type' ->['type']; - export type AnyInvocation = NonNullable[string]>; export type AnyResult = NonNullable; From ed79980dd4fb85319c8d9f7480cb91383bf9f400 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 25 Nov 2023 21:10:22 +1100 Subject: [PATCH 06/65] feat(ui): improved UI for missing node field templates When a node is updated with new fields and workflow needs to be updated, the fields now display "Unknown input/output: FieldName". --- invokeai/frontend/web/public/locales/en.json | 8 ++-- .../nodes/Invocation/fields/InputField.tsx | 38 +++++++++++++++---- .../nodes/Invocation/fields/OutputField.tsx | 36 +++++++++++++++--- .../flow/panels/TopLeftPanel/TopLeftPanel.tsx | 15 +++++--- .../nodes/hooks/useFieldInputInstance.ts | 28 ++++++++++++++ .../nodes/hooks/useFieldInputTemplate.ts | 29 ++++++++++++++ .../nodes/hooks/useFieldOutputInstance.ts | 28 ++++++++++++++ .../nodes/hooks/useFieldOutputTemplate.ts | 29 ++++++++++++++ 8 files changed, 188 insertions(+), 23 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index faa870bd32..6019854862 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -977,16 +977,16 @@ "unhandledInputProperty": "Unhandled input property", "unhandledOutputProperty": "Unhandled output property", "unknownField": "Unknown field", - "unknownFieldType": "$(nodes.unknownField) type", + "unknownFieldType": "$t(nodes.unknownField) type", "unknownNode": "Unknown Node", "unknownNodeType":"$t(nodes.unknownNode) type", "unknownTemplate": "Unknown Template", - "unknownInput": "Unknown input", + "unknownInput": "Unknown input: {{name}}", "unkownInvocation": "Unknown Invocation type", - "unknownOutput": "Unknown output", + "unknownOutput": "Unknown output: {{name}}", "updateNode": "Update Node", "updateApp": "Update App", - "updateAllNodes": "Update All Nodes", + "updateAllNodes": "Update Nodes", "allNodesUpdated": "All Nodes Updated", "unableToUpdateNodes_one": "Unable to update {{count}} node", "unableToUpdateNodes_other": "Unable to update {{count}} nodes", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx index dac9404c26..4d6269e5f4 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputField.tsx @@ -1,13 +1,14 @@ import { Box, Flex, FormControl, FormLabel } from '@chakra-ui/react'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; import { useDoesInputHaveValue } from 'features/nodes/hooks/useDoesInputHaveValue'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; +import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { PropsWithChildren, memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; import EditableFieldTitle from './EditableFieldTitle'; import FieldContextMenu from './FieldContextMenu'; import FieldHandle from './FieldHandle'; import InputFieldRenderer from './InputFieldRenderer'; -import { useTranslation } from 'react-i18next'; interface Props { nodeId: string; @@ -16,7 +17,8 @@ interface Props { const InputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'input'); + const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); + const fieldInstance = useFieldInputInstance(nodeId, fieldName); const doesFieldHaveValue = useDoesInputHaveValue(nodeId, fieldName); const { @@ -28,7 +30,7 @@ const InputField = ({ nodeId, fieldName }: Props) => { } = useConnectionState({ nodeId, fieldName, kind: 'input' }); const isMissingInput = useMemo(() => { - if (fieldTemplate?.fieldKind !== 'input') { + if (!fieldTemplate) { return false; } @@ -45,13 +47,35 @@ const InputField = ({ nodeId, fieldName }: Props) => { } }, [fieldTemplate, isConnected, doesFieldHaveValue]); - if (fieldTemplate?.fieldKind !== 'input') { + if (!fieldTemplate || !fieldInstance) { return ( - {t('nodes.unknownInput')}: {fieldName} + + {t('nodes.unknownInput', { + name: fieldInstance?.label ?? fieldTemplate?.title ?? fieldName, + })} + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx index 4b7ca647f8..994510ef99 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputField.tsx @@ -1,11 +1,12 @@ import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; -import { useFieldTemplate } from 'features/nodes/hooks/useFieldTemplate'; +import { useFieldOutputInstance } from 'features/nodes/hooks/useFieldOutputInstance'; +import { useFieldOutputTemplate } from 'features/nodes/hooks/useFieldOutputTemplate'; import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; import { PropsWithChildren, memo } from 'react'; +import { useTranslation } from 'react-i18next'; import FieldHandle from './FieldHandle'; import FieldTooltipContent from './FieldTooltipContent'; -import { useTranslation } from 'react-i18next'; interface Props { nodeId: string; @@ -14,7 +15,8 @@ interface Props { const OutputField = ({ nodeId, fieldName }: Props) => { const { t } = useTranslation(); - const fieldTemplate = useFieldTemplate(nodeId, fieldName, 'output'); + const fieldTemplate = useFieldOutputTemplate(nodeId, fieldName); + const fieldInstance = useFieldOutputInstance(nodeId, fieldName); const { isConnected, @@ -24,13 +26,35 @@ const OutputField = ({ nodeId, fieldName }: Props) => { shouldDim, } = useConnectionState({ nodeId, fieldName, kind: 'output' }); - if (fieldTemplate?.fieldKind !== 'output') { + if (!fieldTemplate || !fieldInstance) { return ( - {t('nodes.unknownOutput')}: {fieldName} + + {t('nodes.unknownOutput', { + name: fieldTemplate?.title ?? fieldName, + })} + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx index 38aa9bbad7..73d1508c93 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopLeftPanel/TopLeftPanel.tsx @@ -1,13 +1,13 @@ import { Flex } from '@chakra-ui/layout'; import { useAppDispatch } from 'app/store/storeHooks'; -import IAIIconButton from 'common/components/IAIIconButton'; -import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice'; -import { memo, useCallback } from 'react'; -import { FaPlus, FaSync } from 'react-icons/fa'; -import { useTranslation } from 'react-i18next'; import IAIButton from 'common/components/IAIButton'; +import IAIIconButton from 'common/components/IAIIconButton'; import { useGetNodesNeedUpdate } from 'features/nodes/hooks/useGetNodesNeedUpdate'; import { updateAllNodesRequested } from 'features/nodes/store/actions'; +import { addNodePopoverOpened } from 'features/nodes/store/nodesSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { FaExclamationTriangle, FaPlus } from 'react-icons/fa'; const TopLeftPanel = () => { const dispatch = useAppDispatch(); @@ -29,7 +29,10 @@ const TopLeftPanel = () => { onClick={handleOpenAddNodePopover} /> {nodesNeedUpdate && ( - } onClick={handleClickUpdateNodes}> + } + onClick={handleClickUpdateNodes} + > {t('nodes.updateAllNodes')} )} diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts new file mode 100644 index 0000000000..8e95e0fd5b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputInstance.ts @@ -0,0 +1,28 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; + +export const useFieldInputInstance = (nodeId: string, fieldName: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return; + } + return node.data.inputs[fieldName]; + }, + defaultSelectorOptions + ), + [fieldName, nodeId] + ); + + const fieldTemplate = useAppSelector(selector); + + return fieldTemplate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts new file mode 100644 index 0000000000..0f682b53b1 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -0,0 +1,29 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; + +export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return; + } + const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + return nodeTemplate?.inputs[fieldName]; + }, + defaultSelectorOptions + ), + [fieldName, nodeId] + ); + + const fieldTemplate = useAppSelector(selector); + + return fieldTemplate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts new file mode 100644 index 0000000000..0020d334d5 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputInstance.ts @@ -0,0 +1,28 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; + +export const useFieldOutputInstance = (nodeId: string, fieldName: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return; + } + return node.data.outputs[fieldName]; + }, + defaultSelectorOptions + ), + [fieldName, nodeId] + ); + + const fieldTemplate = useAppSelector(selector); + + return fieldTemplate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts new file mode 100644 index 0000000000..e8d0f0899c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -0,0 +1,29 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; +import { useMemo } from 'react'; +import { isInvocationNode } from '../types/invocation'; + +export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { + const selector = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => { + const node = nodes.nodes.find((node) => node.id === nodeId); + if (!isInvocationNode(node)) { + return; + } + const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + return nodeTemplate?.outputs[fieldName]; + }, + defaultSelectorOptions + ), + [fieldName, nodeId] + ); + + const fieldTemplate = useAppSelector(selector); + + return fieldTemplate; +}; From 858bcdd3ff780e46d157402d404e373d29f8d3a8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 25 Nov 2023 21:39:27 +1100 Subject: [PATCH 07/65] feat(nodes): improve docstrings in baseinvocation, disambiguate method names --- invokeai/app/invocations/baseinvocation.py | 70 ++++++++++++++-------- invokeai/app/services/shared/graph.py | 4 +- 2 files changed, 46 insertions(+), 28 deletions(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index cddbd071de..59978c13c1 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -93,6 +93,10 @@ class UIType(str, Enum, metaclass=MetaEnum): Similar to the Any Field, the `collect` and `iterate` nodes make use of `typing.Any`. To facilitate handling these types in the client, we use `UIType._Collection` and `UIType._CollectionItem`. These should not be used by node authors. + + - DEPRECATED Fields + These types are deprecated and should not be used by node authors. A warning will be logged if one is + used, and the type will be ignored. They are included here for backwards compatibility. """ # region Model Field Types @@ -173,10 +177,8 @@ class UIComponent(str, Enum, metaclass=MetaEnum): class InputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used during graph execution, + and by the workflow editor during schema parsing and UI rendering. """ input: Input @@ -198,10 +200,8 @@ class InputFieldJSONSchemaExtra(BaseModel): class OutputFieldJSONSchemaExtra(BaseModel): """ - *DO NOT USE* - This helper class is used to tell the client about our custom field attributes via OpenAPI - schema generation, and Typescript type generation from that schema. It serves no functional - purpose in the backend. + Extra attributes to be added to input fields and their OpenAPI schema. Used by the workflow editor + during schema parsing and UI rendering. """ field_kind: FieldKind @@ -215,11 +215,6 @@ class OutputFieldJSONSchemaExtra(BaseModel): ) -def get_type(klass: BaseModel) -> str: - """Helper function to get an invocation or invocation output's type. This is the default value of the `type` field.""" - return klass.model_fields["type"].default - - def InputField( # copied from pydantic's Field # TODO: Can we support default_factory? @@ -483,29 +478,39 @@ class BaseInvocationOutput(BaseModel): @classmethod def register_output(cls, output: BaseInvocationOutput) -> None: + """Registers an invocation output.""" cls._output_classes.add(output) @classmethod def get_outputs(cls) -> Iterable[BaseInvocationOutput]: + """Gets all invocation outputs.""" return cls._output_classes @classmethod def get_outputs_union(cls) -> UnionType: + """Gets a union of all invocation outputs.""" outputs_union = Union[tuple(cls._output_classes)] # type: ignore [valid-type] return outputs_union # type: ignore [return-value] @classmethod def get_output_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocationOutput.get_outputs()) + """Gets all invocation output types.""" + return (i.get_type() for i in BaseInvocationOutput.get_outputs()) @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + """Adds various UI-facing attributes to the invocation output's OpenAPI schema.""" # Because we use a pydantic Literal field with default value for the invocation type, # it will be typed as optional in the OpenAPI schema. Make it required manually. if "required" not in schema or not isinstance(schema["required"], list): schema["required"] = [] schema["required"].extend(["type"]) + @classmethod + def get_type(cls) -> str: + """Gets the invocation output's type, as provided by the `@invocation_output` decorator.""" + return cls.model_fields["type"].default + model_config = ConfigDict( protected_namespaces=(), validate_assignment=True, @@ -535,21 +540,29 @@ class BaseInvocation(ABC, BaseModel): _invocation_classes: ClassVar[set[BaseInvocation]] = set() + @classmethod + def get_type(cls) -> str: + """Gets the invocation's type, as provided by the `@invocation` decorator.""" + return cls.model_fields["type"].default + @classmethod def register_invocation(cls, invocation: BaseInvocation) -> None: + """Registers an invocation.""" cls._invocation_classes.add(invocation) @classmethod def get_invocations_union(cls) -> UnionType: + """Gets a union of all invocation types.""" invocations_union = Union[tuple(cls._invocation_classes)] # type: ignore [valid-type] return invocations_union # type: ignore [return-value] @classmethod def get_invocations(cls) -> Iterable[BaseInvocation]: + """Gets all invocations, respecting the allowlist and denylist.""" app_config = InvokeAIAppConfig.get_config() allowed_invocations: set[BaseInvocation] = set() for sc in cls._invocation_classes: - invocation_type = get_type(sc) + invocation_type = sc.get_type() is_in_allowlist = ( invocation_type in app_config.allow_nodes if isinstance(app_config.allow_nodes, list) else True ) @@ -562,20 +575,22 @@ class BaseInvocation(ABC, BaseModel): @classmethod def get_invocations_map(cls) -> dict[str, BaseInvocation]: - # Get the type strings out of the literals and into a dictionary - return {get_type(i): i for i in BaseInvocation.get_invocations()} + """Gets a map of all invocation types to their invocation classes.""" + return {i.get_type(): i for i in BaseInvocation.get_invocations()} @classmethod def get_invocation_types(cls) -> Iterable[str]: - return (get_type(i) for i in BaseInvocation.get_invocations()) + """Gets all invocation types.""" + return (i.get_type() for i in BaseInvocation.get_invocations()) @classmethod - def get_output_type(cls) -> BaseInvocationOutput: + def get_output_annotation(cls) -> BaseInvocationOutput: + """Gets the invocation's output annotation (i.e. the return annotation of its `invoke()` method).""" return signature(cls.invoke).return_annotation @staticmethod def json_schema_extra(schema: dict[str, Any], model_class: Type[BaseModel], *args, **kwargs) -> None: - # Add the various UI-facing attributes to the schema. These are used to build the invocation templates. + """Adds various UI-facing attributes to the invocation's OpenAPI schema.""" uiconfig = getattr(model_class, "UIConfig", None) if uiconfig and hasattr(uiconfig, "title"): schema["title"] = uiconfig.title @@ -595,6 +610,10 @@ class BaseInvocation(ABC, BaseModel): pass def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + """ + Internal invoke method, calls `invoke()` after some prep. + Handles optional fields that are required to call `invoke()` and invocation cache. + """ for field_name, field in self.model_fields.items(): if not field.json_schema_extra or callable(field.json_schema_extra): # something has gone terribly awry, we should always have this and it should be a dict @@ -634,9 +653,6 @@ class BaseInvocation(ABC, BaseModel): context.services.logger.debug(f'Skipping invocation cache for "{self.get_type()}": {self.id}') return self.invoke(context) - def get_type(self) -> str: - return self.model_fields["type"].default - id: str = Field( default_factory=uuid_string, description="The id of this instance of an invocation. Must be unique among all instances of invocations.", @@ -693,9 +709,11 @@ RESERVED_PYDANTIC_FIELD_NAMES = {m[0] for m in inspect.getmembers(_Model())} def validate_fields(model_fields: dict[str, FieldInfo], model_type: str) -> None: """ Validates the fields of an invocation or invocation output: - - must not override any pydantic reserved fields - - must not end with "Collection" or "Polymorphic" as these are reserved for internal use - - must be created via `InputField`, `OutputField`, or be an internal field defined in this file + - Must not override any pydantic reserved fields + - Must have a type annotation + - Must have a json_schema_extra dict + - Must have field_kind in json_schema_extra + - Field name must not be reserved, according to its field_kind """ for name, field in model_fields.items(): if name in RESERVED_PYDANTIC_FIELD_NAMES: diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index ee86ef17c6..0d97c0b9a1 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -49,7 +49,7 @@ class Edge(BaseModel): def get_output_field(node: BaseInvocation, field: str) -> Any: node_type = type(node) - node_outputs = get_type_hints(node_type.get_output_type()) + node_outputs = get_type_hints(node_type.get_output_annotation()) node_output_field = node_outputs.get(field) or None return node_output_field @@ -379,7 +379,7 @@ class Graph(BaseModel): raise NodeNotFoundError(f"Edge destination node {edge.destination.node_id} does not exist in the graph") # output fields are not on the node object directly, they are on the output type - if edge.source.field not in source_node.get_output_type().model_fields: + if edge.source.field not in source_node.get_output_annotation().model_fields: raise NodeFieldNotFoundError( f"Edge source field {edge.source.field} does not exist in node {edge.source.node_id}" ) From 514c49d946686d91ee3c3ce8538e59f0e88c7690 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sat, 25 Nov 2023 21:40:12 +1100 Subject: [PATCH 08/65] feat(nodes): warn if node has no version specified; fall back on 1.0.0 --- invokeai/app/invocations/baseinvocation.py | 3 +++ invokeai/app/services/shared/graph.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 59978c13c1..22e68ea3cb 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -811,6 +811,9 @@ def invocation( except ValueError as e: raise InvalidVersionError(f'Invalid version string for node "{invocation_type}": "{version}"') from e cls.UIConfig.version = version + else: + logger.warn(f'No version specified for node "{invocation_type}", using "1.0.0"') + cls.UIConfig.version = "1.0.0" if use_cache is not None: cls.model_fields["use_cache"].default = use_cache diff --git a/invokeai/app/services/shared/graph.py b/invokeai/app/services/shared/graph.py index 0d97c0b9a1..c825a84011 100644 --- a/invokeai/app/services/shared/graph.py +++ b/invokeai/app/services/shared/graph.py @@ -188,7 +188,7 @@ class GraphInvocationOutput(BaseInvocationOutput): # TODO: Fill this out and move to invocations -@invocation("graph") +@invocation("graph", version="1.0.0") class GraphInvocation(BaseInvocation): """Execute a graph""" From ab944bd13ae34aac98547a5809219204f5c00383 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 00:34:51 +1100 Subject: [PATCH 09/65] feat(ui): remove `docs/` from prettierignore --- invokeai/frontend/web/.prettierignore | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/frontend/web/.prettierignore b/invokeai/frontend/web/.prettierignore index bdf02d5c9e..05782f1f53 100644 --- a/invokeai/frontend/web/.prettierignore +++ b/invokeai/frontend/web/.prettierignore @@ -9,6 +9,5 @@ index.html .yalc/ *.scss src/services/api/schema.d.ts -docs/ static/ src/theme/css/overlayscrollbars.css From 803fb393bb0d137128da3468e7ac4a0a06828a23 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 00:35:46 +1100 Subject: [PATCH 10/65] fix(ui): fix mis-named typeguard --- invokeai/frontend/web/src/features/nodes/types/invocation.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 216db437b9..70403169de 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -71,7 +71,7 @@ export const isInvocationNode = ( export const isNotesNode = ( node?: Node ): node is Node => Boolean(node && node.type === 'notes'); -export const isProgressImageNode = ( +export const isCurrentImageNode = ( node?: Node ): node is Node => Boolean(node && node.type === 'current_image'); From 5386a286fd218ec6442e3c4bfae673392a0e0c0b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 02:27:58 +1100 Subject: [PATCH 11/65] feat(ui): constrain w/h in imageoutput schema --- invokeai/frontend/web/src/features/nodes/types/common.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/common.ts b/invokeai/frontend/web/src/features/nodes/types/common.ts index 0cab248c80..460b301685 100644 --- a/invokeai/frontend/web/src/features/nodes/types/common.ts +++ b/invokeai/frontend/web/src/features/nodes/types/common.ts @@ -206,8 +206,8 @@ export type ProgressImage = z.infer; // #region ImageOutput export const zImageOutput = z.object({ image: zImageField, - width: z.number().int(), - height: z.number().int(), + width: z.number().int().gt(0), + height: z.number().int().gt(0), type: z.literal('image_output'), }); export type ImageOutput = z.infer; From 296741306c7cc9fa2cb4538633e84b4793bccdea Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 02:35:44 +1100 Subject: [PATCH 12/65] feat(ui): update frontend README --- invokeai/frontend/web/docs/README.md | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/invokeai/frontend/web/docs/README.md b/invokeai/frontend/web/docs/README.md index 4783999419..5f9e3c2c55 100644 --- a/invokeai/frontend/web/docs/README.md +++ b/invokeai/frontend/web/docs/README.md @@ -13,6 +13,7 @@ - [Vite](#vite) - [i18next & Weblate](#i18next--weblate) - [openapi-typescript](#openapi-typescript) + - [reactflow](#reactflow) - [Client Types Generation](#client-types-generation) - [Package Scripts](#package-scripts) - [Contributing](#contributing) @@ -26,7 +27,7 @@ The UI is a fairly straightforward Typescript React app. ## Core Libraries -The app makes heavy use of a handful of libraries. +InvokeAI's UI is made possible by a number of excellent open-source libraries. The most heavily-used are listed below, but there are many others. ### Redux Toolkit @@ -57,12 +58,20 @@ We use [redux-remember](https://github.com/zewish/redux-remember) for persistenc ### i18next & Weblate -We use [i18next](https://github.com/i18next/react-i18next) for localisation, but translation to languages other than English happens on our [Weblate](https://hosted.weblate.org/engage/invokeai/) project. **Only the English source strings should be changed on this repo.** +We use [i18next](https://github.com/i18next/react-i18next) for localization, but translation to languages other than English happens on our [Weblate](https://hosted.weblate.org/engage/invokeai/) project. **Only the English source strings should be changed on this repo.** ### openapi-typescript [openapi-typescript](https://github.com/drwpow/openapi-typescript) is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md. +### reactflow + +[reactflow](https://github.com/xyflow/xyflow) powers the Workflow Editor. + +### zod + +[zod](https://github.com/colinhacks/zod) schemas are used to model data structures and provide runtime validation. + ## Client Types Generation We use [`openapi-typescript`](https://github.com/drwpow/openapi-typescript) to generate types from the app's OpenAPI schema. From 8f2cf3019139f654f6afa1d8890404ae3da716f7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 02:36:14 +1100 Subject: [PATCH 13/65] feat(ui): add workflows design & implementation doc (WIP) --- .../docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md diff --git a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md new file mode 100644 index 0000000000..150f06b45d --- /dev/null +++ b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md @@ -0,0 +1,160 @@ +# Workflows - Design and Implementation + + + + + +- [Workflows - Design and Implementation](#workflows---design-and-implementation) + - [Linear UI](#linear-ui) + - [Workflow Editor](#workflow-editor) + - [Workflows](#workflows) + - [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph) + - [Nodes vs Invocations](#nodes-vs-invocations) + - [Workflow Linear View](#workflow-linear-view) + - [OpenAPI Schema Parsing](#openapi-schema-parsing) + - [Field Instances and Templates](#field-instances-and-templates) + - [Stateful vs Stateless Fields](#stateful-vs-stateless-fields) + - [Collection and Polymorphic Fields](#collection-and-polymorphic-fields) + - [Implementation](#implementation) + + + +InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images. + +Nodes have any number of **input fields** and one **output field**. Edges connect nodes together via their inputs and outputs. + +During execution, a nodes' output may be passed along to any number of other nodes' inputs. + +We provide two ways to build graphs in the frontend: the [Linear UI](#linear-ui) and [Workflow Editor](#workflow-editor). + +## Linear UI + +This includes the **Text to Image**, **Image to Image** and **Unified Canvas** tabs. + +The user-managed parameters on these tabs are stored as simple objects in the application state. When the user invokes, adding a generation to the queue, we internally build a graph from these parameters. + +This logic can be fairly complex due to the range of features available and their interactions. Depending on the parameters selected, the graph may be very different. Building graphs in code can be challenging - you are trying to construct a non-linear structure in a linear context. + +The simplest graph building logic is for **Text to Image** with a SD1.5 model: +`invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts` + +There are many other graph builders in the same folder for different tabs or base models (e.g. SDXL). Some are pretty hairy. + +In the Linear UI, we go straight from **simple application state** to **graph** via these builders. + +## Workflow Editor + +The Workflow Editor is a visual graph editor, allowing users to draw edges from node to node to construct a graph. This _far_ more approachable way to create complex graphs. + +InvokeAI uses the [reactflow](https://github.com/xyflow/xyflow) library to power the Workflow Editor. It provides both a graph editor UI and manages its own internal graph state. + +### Workflows + +So far, we've described two different graph representations used by InvokeAI - the InvokeAI execution graph and the reactflow state. + +Neither of these is sufficient to represent a _workflow_, though. A workflow must have a representation of a its graph's nodes and edges, but it also has other data: + +- Name +- Description +- Version +- Notes +- [Exposed fields](#workflow-linear-view) +- Author, tags, category, etc. + +Workflows should have other qualities: + +- Portable: you should be able to load a workflow created by another person. +- Resilient: you should be able to "upgrade" a workflow as the application changes. +- Abstract: as much as is possible, workflows should not be married to the specific implementation details of the application. + +To support these qualities, workflows are serializable, have a versioned schemas, and represent graphs as minimally as possible. Fortunately, the reactflow state for nodes and edges works perfectly for this.. + +#### Workflow -> reactflow state -> InvokeAI graph + +Given a workflow, we need to be able to derive reactflow state and/or an InvokeAI graph from it. + +The first step - workflow to reactflow state - is very simple. The logic is in `invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts`, in the `workflowLoaded` reducer. + +The reactflow state is, however, structurally incompatible with our backend's graph structure. When a user invokes on a Workflow, we need to convert the reactflow state into an InvokeAI graph. This is far simpler than the graph building logic from the Linear UI: +`invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts` + +#### Nodes vs Invocations + +We often use the terms "node" and "invocation" interchangeably, but they may refer to different things in the frontend. + +reactflow [has its own definitions](https://reactflow.dev/learn/concepts/terms-and-definitions) of "node", "edge" and "handle" which are closely related to InvokeAI graph concepts. + +- A reactflow node is related to an InvokeAI invocation. It has a "data" property, which holds the InvokeAI-specific invocation data. +- A reactflow edge is roughly equivalent to an InvokeAI edge. +- A reactflow handle is roughly equivalent to an InvokeAI input or output field. + +#### Workflow Linear View + +Graphs are very capable data structures, but not everyone wants to work with them all the time. + +To allow less technical users - or anyone who wants a less visually noisy workspace - to benefit from the power of nodes, InvokeAI has a workflow feature called the Linear View. + +A workflow input field can be added to this Linear View, and its input component can be presented similarly to the Linear UI tabs. Internally, we add the field to the workflow's list of exposed fields. + +### OpenAPI Schema Parsing + +OpenAPI is a schema specification that can represent complex data structures and relationships. The backend is capable of generating an OpenAPI schema for all invocations. + +When the UI connects, it requests this schema and parses each invocation into an **invocation template**. Invocation templates have a number of properties, like title, description and type, but the most important ones are their input and output **field templates**. + +Invocation and field templates are the "source of truth" for graphs, because they indicate what the backend is able to process. + +When a user adds a new node to their workflow, these templates are used to instantiate a node with fields instantiated from the input and output field templates. + +#### Field Instances and Templates + +Field templates consist of: + +- Name: the identifier of the field, its variable name in python +- Type: derived from the field's type annotation in python (e.g. IntegerField, ImageField, MainModelField) +- Constraints: derived from the field's creation args in python (e.g. minimum value for an integer) +- Default value: optionally provided in the field's creation args (e.g. 42 for an integer) + +Field instances are created from the templates and have name, type and optionally a value. + +The type of the field determines the UI components that are rendered for it. + +A field instance's name associates it with its template. + +#### Stateful vs Stateless Fields + +**Stateful** fields store their value in the frontend graph. Think primitives, model identifiers, images, etc. Fields are only stateful if the frontend allows the user to directly input a value for them. + +Many field types, however, are **stateless**. An example is a `UNetField`, which contains some data describing a UNet. Users cannot directly provide this data - it is created and consumed in the backend. + +Stateless fields do not store their value in the node, so their field instances do not have values. + +"Custom" fields will always be treated as stateless fields. + +#### Collection and Polymorphic Fields + +Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field. + +If a field is annotated in python as a list, its field type is parsed and flagged as a collection type (e.g. `list[int]`). + +If it is annotated as a union of a type and list, the type will be flagged as a polymorphic type (e.g. `Union[int, list[int]]`). Fields may not be unions of different types (e.g. `Union[int, list[str]]` and `Union[int, str]` are not allowed). + +## Implementation + +The majority of data structures in the backend are [pydantic](https://github.com/pydantic/pydantic) models. Pydantic provides OpenAPI schemas for all models and we then generate TypeScript types from those. + +Workflows and all related data are modeled in the frontend using [zod](https://github.com/colinhacks/zod). Related types are inferred from the zod schemas. + +### Schemas and Types + +The schemas, inferred types, type guards and related constants are in `invokeai/frontend/web/src/features/nodes/types/`. + +Roughly in order from lowest-level to highest: + +- `common.ts`: stateful field data, and couple other misc types +- `field.ts`: fields - types, values, instances, templates +- `metadata.ts`: core metadata +- `invocation.ts`: invocations and other node types +- `workflow.ts`: workflows and constituents + +### Workflow Migrations From e85f2254f02d6ef17b6556268eb3f16e407e0ddd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 20:43:10 +1100 Subject: [PATCH 14/65] feat(ui): update fields docstring --- invokeai/frontend/web/src/features/nodes/types/field.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index dd1c50f6e3..26f2a72ee2 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -13,10 +13,10 @@ import { } from './common'; /** - * zod schemas & inferred types for input field values. + * zod schemas & inferred types for fields. * - * These schemas and types are only required for field types that have UI components and allow the - * user to directly provide values. + * These schemas and types are only required for stateful field - fields that have UI components + * and allow the user to directly provide values. * * This includes primitive values (numbers, strings, booleans), models, scheduler, etc. * From a703e1b3d30cabbee1facd6f24b2c84aa156568c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 20:44:39 +1100 Subject: [PATCH 15/65] feat(ui): add errors for invalid polymorphic types --- invokeai/frontend/web/public/locales/en.json | 44 ++++++++----------- .../src/features/nodes/util/parseFieldType.ts | 16 +++++++ 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6019854862..a1b2c2b6e7 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1,7 +1,7 @@ { "accessibility": { "copyMetadataJson": "Copy metadata JSON", - "createIssue":"Create Issue", + "createIssue": "Create Issue", "exitViewer": "Exit Viewer", "flipHorizontally": "Flip Horizontally", "flipVertically": "Flip Vertically", @@ -13,7 +13,7 @@ "nextImage": "Next Image", "previousImage": "Previous Image", "reset": "Reset", - "resetUI":"$t(accessibility.reset) UI", + "resetUI": "$t(accessibility.reset) UI", "rotateClockwise": "Rotate Clockwise", "rotateCounterClockwise": "Rotate Counter-Clockwise", "showGalleryPanel": "Show Gallery Panel", @@ -59,7 +59,7 @@ "back": "Back", "batch": "Batch Manager", "cancel": "Cancel", - "copyError":"$t(gallery.copy) Error", + "copyError": "$t(gallery.copy) Error", "close": "Close", "on": "On", "checkpoint": "Checkpoint", @@ -76,7 +76,7 @@ "error": "Error", "file": "File", "folder": "Folder", - "format":"format", + "format": "format", "generate": "Generate", "githubLabel": "Github", "hotkeysLabel": "Hotkeys", @@ -355,9 +355,9 @@ "autoSwitchNewImages": "Auto-Switch to New Images", "copy": "Copy", "currentlyInUse": "This image is currently in use in the following features:", - "drop":"Drop", - "dropOrUpload":"$t(gallery.drop) or Upload", - "dropToUpload":"$t(gallery.drop) to Upload", + "drop": "Drop", + "dropOrUpload": "$t(gallery.drop) or Upload", + "dropToUpload": "$t(gallery.drop) to Upload", "deleteImage": "Delete Image", "deleteImageBin": "Deleted images will be sent to your operating system's Bin.", "deleteImagePermanent": "Deleted images cannot be restored.", @@ -775,7 +775,7 @@ "esrganModel": "ESRGAN Model", "loading": "loading", "noLoRAsAvailable": "No LoRAs available", - "noLoRAsLoaded":"No LoRAs Loaded", + "noLoRAsLoaded": "No LoRAs Loaded", "noMatchingLoRAs": "No matching LoRAs", "noMatchingModels": "No matching Models", "noModelsAvailable": "No models available", @@ -787,7 +787,7 @@ "nodes": { "addNode": "Add Node", "addNodeToolTip": "Add Node (Shift+A, Space)", - "addLinearView":"Add to Linear View", + "addLinearView": "Add to Linear View", "animatedEdges": "Animated Edges", "animatedEdgesHelp": "Animate selected edges and edges connected to selected nodes", "boardField": "Board", @@ -971,6 +971,8 @@ "outputFieldTypeParseError": "Unable to parse type of output field {{node}}.{{field}} ({{message}})", "unableToExtractSchemaNameFromRef": "unable to extract schema name from ref", "unsupportedArrayItemType": "unsupported array item type \"{{type}}\"", + "unsupportedAnyOfLength": "too many union members ({{count}})", + "unsupportedMismatchedUnion": "mismatched polymorphic type with members {{firstType}} and {{secondType}}", "unableToParseFieldType": "unable to parse field type", "uNetField": "UNet", "uNetFieldDescription": "UNet submodel.", @@ -979,7 +981,7 @@ "unknownField": "Unknown field", "unknownFieldType": "$t(nodes.unknownField) type", "unknownNode": "Unknown Node", - "unknownNodeType":"$t(nodes.unknownNode) type", + "unknownNodeType": "$t(nodes.unknownNode) type", "unknownTemplate": "Unknown Template", "unknownInput": "Unknown input: {{name}}", "unkownInvocation": "Unknown Invocation type", @@ -1353,15 +1355,11 @@ }, "compositingBlur": { "heading": "Blur", - "paragraphs": [ - "The blur radius of the mask." - ] + "paragraphs": ["The blur radius of the mask."] }, "compositingBlurMethod": { "heading": "Blur Method", - "paragraphs": [ - "The method of blur applied to the masked area." - ] + "paragraphs": ["The method of blur applied to the masked area."] }, "compositingCoherencePass": { "heading": "Coherence Pass", @@ -1371,9 +1369,7 @@ }, "compositingCoherenceMode": { "heading": "Mode", - "paragraphs": [ - "The mode of the Coherence Pass." - ] + "paragraphs": ["The mode of the Coherence Pass."] }, "compositingCoherenceSteps": { "heading": "Steps", @@ -1391,9 +1387,7 @@ }, "compositingMaskAdjustments": { "heading": "Mask Adjustments", - "paragraphs": [ - "Adjust the mask." - ] + "paragraphs": ["Adjust the mask."] }, "controlNetBeginEnd": { "heading": "Begin / End Step Percentage", @@ -1451,9 +1445,7 @@ }, "infillMethod": { "heading": "Infill Method", - "paragraphs": [ - "Method to infill the selected area." - ] + "paragraphs": ["Method to infill the selected area."] }, "lora": { "heading": "LoRA Weight", @@ -1593,7 +1585,7 @@ "redo": "Redo", "resetView": "Reset View", "saveBoxRegionOnly": "Save Box Region Only", - "saveMask":"Save $t(unifiedCanvas.mask)", + "saveMask": "Save $t(unifiedCanvas.mask)", "saveToGallery": "Save To Gallery", "scaledBoundingBox": "Scaled Bounding Box", "showCanvasDebugInfo": "Show Additional Canvas Info", diff --git a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts index 133a3d11c9..2d25ab9faa 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts @@ -115,6 +115,15 @@ export const parseFieldType = ( * Any other cases we ignore. */ + if (filteredAnyOf.length !== 2) { + // This is a union of more than 2 types, which we don't support + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedAnyOfLength', { + count: filteredAnyOf.length, + }) + ); + } + let firstType: string | undefined; let secondType: string | undefined; @@ -154,6 +163,13 @@ export const parseFieldType = ( isPolymorphic: true, // <-- don't forget, polymorphic! }; } + + throw new UnsupportedFieldTypeError( + t('nodes.unsupportedMismatchedUnion', { + firstType, + secondType, + }) + ); } } else if (schemaObject.enum) { return { name: 'EnumField', isCollection: false, isPolymorphic: false }; From ad9c954a58e5392c4dc7e0fa816afa4894d4adf5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 21:06:38 +1100 Subject: [PATCH 16/65] feat(ui): move field output template builder to own file --- .../nodes/util/buildFieldOutputTemplate.ts | 24 +++++++++++++++++++ .../src/features/nodes/util/parseSchema.ts | 23 +++++++----------- 2 files changed, 33 insertions(+), 14 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts new file mode 100644 index 0000000000..05e3c66386 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts @@ -0,0 +1,24 @@ +import { startCase } from 'lodash-es'; +import { FieldOutputTemplate, FieldType } from '../types/field'; +import { InvocationFieldSchema } from '../types/openapi'; + +export const buildFieldOutputTemplate = ( + fieldSchema: InvocationFieldSchema, + fieldName: string, + fieldType: FieldType +): FieldOutputTemplate => { + const { title, description, ui_hidden, ui_type, ui_order } = fieldSchema; + + const fieldOutputTemplate: FieldOutputTemplate = { + fieldKind: 'output', + name: fieldName, + title: title ?? (fieldName ? startCase(fieldName) : ''), + description: description ?? '', + type: fieldType, + ui_hidden, + ui_type, + ui_order, + }; + + return fieldOutputTemplate; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts index 2c59b6cb14..81d79d2976 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts @@ -1,7 +1,9 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { reduce, startCase } from 'lodash-es'; +import { t } from 'i18next'; +import { reduce } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; +import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; import { InvocationTemplate } from '../types/invocation'; import { @@ -11,9 +13,8 @@ import { isInvocationSchemaObject, } from '../types/openapi'; import { buildFieldInputTemplate } from './buildFieldInputTemplate'; +import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; import { parseFieldType } from './parseFieldType'; -import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; -import { t } from 'i18next'; const RESERVED_INPUT_FIELD_NAMES = ['id', 'type', 'use_cache']; const RESERVED_OUTPUT_FIELD_NAMES = ['type']; @@ -209,17 +210,11 @@ export const parseSchema = ( return outputsAccumulator; } - const fieldOutputTemplate: FieldOutputTemplate = { - fieldKind: 'output', - name: propertyName, - title: - property.title ?? (propertyName ? startCase(propertyName) : ''), - description: property.description ?? '', - type: fieldType, - ui_hidden: property.ui_hidden ?? false, - ui_type: property.ui_type, - ui_order: property.ui_order, - }; + const fieldOutputTemplate = buildFieldOutputTemplate( + property, + propertyName, + fieldType + ); outputsAccumulator[propertyName] = fieldOutputTemplate; } catch (e) { From 654591cbf30068ef8c1eaadbe2b8b88fce2a3748 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 21:07:49 +1100 Subject: [PATCH 17/65] feat(ui): make buildFieldInputTemplate arg name consistent --- .../web/src/features/nodes/util/buildFieldInputTemplate.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts index 8d11ac25b9..0deddf0dea 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts @@ -326,7 +326,7 @@ export const TEMPLATE_BUILDER_MAP: Record< export const buildFieldInputTemplate = ( fieldSchema: InvocationFieldSchema, - name: string, + fieldName: string, fieldType: FieldType ): FieldInputTemplate => { const { @@ -342,8 +342,8 @@ export const buildFieldInputTemplate = ( // This is the base field template that is common to all fields. The builder function will add all other // properties to this template. const baseField: Omit = { - name, - title: fieldSchema.title ?? (name ? startCase(name) : ''), + name: fieldName, + title: fieldSchema.title ?? (fieldName ? startCase(fieldName) : ''), required, description: fieldSchema.description ?? '', fieldKind: 'input' as const, From 42370939a803a8948d6bb34266b08c2233ed215a Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 21:21:45 +1100 Subject: [PATCH 18/65] feat(ui): update workflows design & implementation docs (wip) --- .../docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md | 161 ++++++++++++++---- 1 file changed, 131 insertions(+), 30 deletions(-) diff --git a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md index 150f06b45d..70013499d0 100644 --- a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md +++ b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md @@ -5,29 +5,47 @@ - [Workflows - Design and Implementation](#workflows---design-and-implementation) - - [Linear UI](#linear-ui) - - [Workflow Editor](#workflow-editor) - - [Workflows](#workflows) - - [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph) - - [Nodes vs Invocations](#nodes-vs-invocations) - - [Workflow Linear View](#workflow-linear-view) - - [OpenAPI Schema Parsing](#openapi-schema-parsing) - - [Field Instances and Templates](#field-instances-and-templates) - - [Stateful vs Stateless Fields](#stateful-vs-stateless-fields) - - [Collection and Polymorphic Fields](#collection-and-polymorphic-fields) + - [Design](#design) + - [Linear UI](#linear-ui) + - [Workflow Editor](#workflow-editor) + - [Workflows](#workflows) + - [Workflow -\> reactflow state -\> InvokeAI graph](#workflow---reactflow-state---invokeai-graph) + - [Nodes vs Invocations](#nodes-vs-invocations) + - [Workflow Linear View](#workflow-linear-view) + - [OpenAPI Schema](#openapi-schema) + - [Field Instances and Templates](#field-instances-and-templates) + - [Stateful vs Stateless Fields](#stateful-vs-stateless-fields) + - [Collection and Polymorphic Fields](#collection-and-polymorphic-fields) - [Implementation](#implementation) + - [zod Schemas and Types](#zod-schemas-and-types) + - [OpenAPI Schema Parsing](#openapi-schema-parsing) + - [Parsing Field Types](#parsing-field-types) + - [Primitive Types](#primitive-types) + - [Complex Types](#complex-types) + - [Collection Types](#collection-types) + - [Polymorphic Types](#polymorphic-types) + - [Optional Fields](#optional-fields) + - [Building Field Input Templates](#building-field-input-templates) + - [Building Field Output Templates](#building-field-output-templates) + - [Workflow Migrations](#workflow-migrations) InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images. -Nodes have any number of **input fields** and one **output field**. Edges connect nodes together via their inputs and outputs. +Nodes have any number of **input fields** and **output fields**. Edges connect nodes together via their inputs and outputs. Fields have data types which dictate how they may be connected. -During execution, a nodes' output may be passed along to any number of other nodes' inputs. +During execution, a nodes' outputs may be passed along to any number of other nodes' inputs. -We provide two ways to build graphs in the frontend: the [Linear UI](#linear-ui) and [Workflow Editor](#workflow-editor). +Workflows are an enriched abstraction over a graph. -## Linear UI +## Design + +InvokeAI provide two ways to build graphs in the frontend: the [Linear UI](#linear-ui) and [Workflow Editor](#workflow-editor). + +To better understand the use case and challenges related to workflows, we will review both of these modes. + +### Linear UI This includes the **Text to Image**, **Image to Image** and **Unified Canvas** tabs. @@ -42,17 +60,15 @@ There are many other graph builders in the same folder for different tabs or bas In the Linear UI, we go straight from **simple application state** to **graph** via these builders. -## Workflow Editor +### Workflow Editor The Workflow Editor is a visual graph editor, allowing users to draw edges from node to node to construct a graph. This _far_ more approachable way to create complex graphs. InvokeAI uses the [reactflow](https://github.com/xyflow/xyflow) library to power the Workflow Editor. It provides both a graph editor UI and manages its own internal graph state. -### Workflows +#### Workflows -So far, we've described two different graph representations used by InvokeAI - the InvokeAI execution graph and the reactflow state. - -Neither of these is sufficient to represent a _workflow_, though. A workflow must have a representation of a its graph's nodes and edges, but it also has other data: +A workflow is a representation of a graph plus additional metadata: - Name - Description @@ -69,7 +85,7 @@ Workflows should have other qualities: To support these qualities, workflows are serializable, have a versioned schemas, and represent graphs as minimally as possible. Fortunately, the reactflow state for nodes and edges works perfectly for this.. -#### Workflow -> reactflow state -> InvokeAI graph +##### Workflow -> reactflow state -> InvokeAI graph Given a workflow, we need to be able to derive reactflow state and/or an InvokeAI graph from it. @@ -78,7 +94,7 @@ The first step - workflow to reactflow state - is very simple. The logic is in ` The reactflow state is, however, structurally incompatible with our backend's graph structure. When a user invokes on a Workflow, we need to convert the reactflow state into an InvokeAI graph. This is far simpler than the graph building logic from the Linear UI: `invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts` -#### Nodes vs Invocations +##### Nodes vs Invocations We often use the terms "node" and "invocation" interchangeably, but they may refer to different things in the frontend. @@ -88,7 +104,7 @@ reactflow [has its own definitions](https://reactflow.dev/learn/concepts/terms-a - A reactflow edge is roughly equivalent to an InvokeAI edge. - A reactflow handle is roughly equivalent to an InvokeAI input or output field. -#### Workflow Linear View +##### Workflow Linear View Graphs are very capable data structures, but not everyone wants to work with them all the time. @@ -96,7 +112,7 @@ To allow less technical users - or anyone who wants a less visually noisy worksp A workflow input field can be added to this Linear View, and its input component can be presented similarly to the Linear UI tabs. Internally, we add the field to the workflow's list of exposed fields. -### OpenAPI Schema Parsing +#### OpenAPI Schema OpenAPI is a schema specification that can represent complex data structures and relationships. The backend is capable of generating an OpenAPI schema for all invocations. @@ -106,7 +122,7 @@ Invocation and field templates are the "source of truth" for graphs, because the When a user adds a new node to their workflow, these templates are used to instantiate a node with fields instantiated from the input and output field templates. -#### Field Instances and Templates +##### Field Instances and Templates Field templates consist of: @@ -121,7 +137,7 @@ The type of the field determines the UI components that are rendered for it. A field instance's name associates it with its template. -#### Stateful vs Stateless Fields +##### Stateful vs Stateless Fields **Stateful** fields store their value in the frontend graph. Think primitives, model identifiers, images, etc. Fields are only stateful if the frontend allows the user to directly input a value for them. @@ -131,7 +147,7 @@ Stateless fields do not store their value in the node, so their field instances "Custom" fields will always be treated as stateless fields. -#### Collection and Polymorphic Fields +##### Collection and Polymorphic Fields Field types have a name and two flags which may identify it as a **collection** or **polymorphic** field. @@ -143,18 +159,103 @@ If it is annotated as a union of a type and list, the type will be flagged as a The majority of data structures in the backend are [pydantic](https://github.com/pydantic/pydantic) models. Pydantic provides OpenAPI schemas for all models and we then generate TypeScript types from those. +The OpenAPI schema is parsed at runtime into our invocation templates. + Workflows and all related data are modeled in the frontend using [zod](https://github.com/colinhacks/zod). Related types are inferred from the zod schemas. -### Schemas and Types +> In python, invocations are pydantic models with fields. These fields become inputs. The invocation's `invoke()` function returns a pydantic model - its output. Like the invocation itself, the output model has any number of fields, which become outputs. -The schemas, inferred types, type guards and related constants are in `invokeai/frontend/web/src/features/nodes/types/`. +### zod Schemas and Types -Roughly in order from lowest-level to highest: +The zod schemas, inferred types, and type guards are in `invokeai/frontend/web/src/features/nodes/types/`. + +Roughly order from lowest-level to highest: - `common.ts`: stateful field data, and couple other misc types - `field.ts`: fields - types, values, instances, templates -- `metadata.ts`: core metadata - `invocation.ts`: invocations and other node types - `workflow.ts`: workflows and constituents +We customize the OpenAPI schema to include additional properties on invocation and field schemas. To facilitate parsing this schema into templates, we modify/wrap the types from [openapi-types](https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types) in `openapi.ts`. + +### OpenAPI Schema Parsing + +The entrypoint for the OpenAPI schema parsing is `invokeai/frontend/web/src/features/nodes/util/parseSchema.ts`. + +General logic flow: + +- Iterate over all invocation schema objects + - Extract relevant invocation-level attributes (e.g. title, type, version, etc) + - Iterate over the invocation's input fields + - [Parse each field's type](#parsing-field-types) + - [Build a field input template](#building-field-input-templates) from the type - either a stateful template or "generic" stateless template + - Iterate over the invocation's output fields + - Parse the field's type (same as inputs) + - [Build a field output template](#building-field-output-templates) + - Assemble the attributes and fields into an invocation template + +Most of these involve very straightforward `reduce`s, but the less intuitive steps are detailed below. + +#### Parsing Field Types + +Field types are represented as structured objects: + +```ts +type FieldType = { + name: string; + isCollection: boolean; + isPolymorphic: boolean; +}; +``` + +The parsing logic is in `invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts`. + +There are 4 general cases for field type parsing. + +##### Primitive Types + +When a field is annotated as a primitive values (e.g. `int`, `str`, `float`), the field type parsing is fairly straightforward. The field is represented by a simple OpenAPI **schema object**, which has a `type` property. + +We create a field type name from this `type` string (e.g. `string` -> `StringField`). + +##### Complex Types + +When a field is annotated as a pydantic model (e.g. `ImageField`, `MainModelField`, `ControlField`), it is represented as a **reference object**. Reference objects are pointers to another schema or reference object within the schema. + +We need to **dereference**[^dereference] the schema to pull these out. Dereferencing may require recursion. We use the reference object's name directly for the field type name. + +##### Collection Types + +When a field is annotated as a list of a single type, the schema object has an `items` property. They may be a schema object or reference object and must be parsed to determine the item type. + +We use the item type for field type name, adding `isCollection: true` to the field type. + +##### Polymorphic Types + +When a field is annotated as a union of a type and list of that type, the schema object has an `anyOf` property, which holds a list of valid types for the union. + +After verifying that the union has two members (a type and list of the same type), we use the type for field type name, adding `isPolymorphic: true` to the field type. + +##### Optional Fields + +In OpenAPI v3.1, when an object is optional, it is put into an `anyOf` along with a primitive schema object with `type: 'null'`. + +Handling this adds a fair bit of complexity, as we now must filter out the `'null'` types and work with the remaining types as described above. + +If there is a single remaining schema object, we must recursively call to `parseFieldType()` to get parse it. + +[^dereference]: Unfortunately, at this time, we've had limited success using external libraries to deference at runtime, so we do this ourselves. + +#### Building Field Input Templates + +Now that we have a field type, we can build an input template for the field. This logic is in `invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts`. + +Stateful fields all get a function to build their template, while stateless fields are constructed directly. This is possible because stateless fields have no default value or constraints. + +#### Building Field Output Templates + +Field outputs are similar to stateless fields - they do not have any value in the frontend. When building their templates, we don't need a special function for each field type. + +The logic is in `invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts`. + ### Workflow Migrations From 4309f3bd58f2ab440981d6b9123b9aade15c3a9b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 26 Nov 2023 23:33:18 +1100 Subject: [PATCH 19/65] feat(ui): tidy node-related types --- .../features/nodes/hooks/useBuildNodeData.ts | 4 +-- .../src/features/nodes/store/nodesSlice.ts | 10 ++++---- .../web/src/features/nodes/store/types.ts | 14 +++++------ .../nodes/store/util/buildNodeData.ts | 6 +---- .../web/src/features/nodes/types/constants.ts | 9 +++++++ .../src/features/nodes/types/invocation.ts | 25 ++++++++++--------- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts index 694261d943..46f11b3823 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts @@ -9,7 +9,7 @@ import { buildNotesNode, } from '../store/util/buildNodeData'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants'; -import { AnyNodeData, InvocationTemplate } from '../types/invocation'; +import { AnyNode, InvocationTemplate } from '../types/invocation'; const templatesSelector = createSelector( [(state: RootState) => state.nodes], (nodes) => nodes.nodeTemplates @@ -26,7 +26,7 @@ export const useBuildNodeData = () => { return useCallback( // string here is "any invocation type" - (type: string | 'current_image' | 'notes'): Node => { + (type: string | 'current_image' | 'notes'): AnyNode => { let _x = window.innerWidth / 2; let _y = window.innerHeight / 2; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 0c21d02fed..f91fe04fde 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -28,7 +28,7 @@ import { appSocketQueueItemStatusChanged, } from 'services/events/actions'; import { v4 as uuidv4 } from 'uuid'; -import { DRAG_HANDLE_CLASSNAME } from '../types/constants'; +import { SHARED_NODE_PROPERTIES } from '../types/constants'; import { BoardFieldValue, BooleanFieldValue, @@ -50,7 +50,7 @@ import { VAEModelFieldValue, } from '../types/field'; import { - AnyNodeData, + AnyNode, InvocationTemplate, isInvocationNode, isNotesNode, @@ -157,7 +157,7 @@ const nodesSlice = createSlice({ } state.nodes[nodeIndex] = action.payload.node; }, - nodeAdded: (state, action: PayloadAction>) => { + nodeAdded: (state, action: PayloadAction) => { const node = action.payload; const position = findUnoccupiedPosition( state.nodes, @@ -520,7 +520,7 @@ const nodesSlice = createSlice({ state.edges = applyEdgeChanges(edgeChanges, state.edges); } }, - nodesDeleted: (state, action: PayloadAction[]>) => { + nodesDeleted: (state, action: PayloadAction) => { action.payload.forEach((node) => { state.workflow.exposedFields = state.workflow.exposedFields.filter( (f) => f.nodeId !== node.id @@ -731,7 +731,7 @@ const nodesSlice = createSlice({ state.nodes = applyNodeChanges( nodes.map((node) => ({ - item: { ...node, dragHandle: `.${DRAG_HANDLE_CLASSNAME}` }, + item: { ...node, ...SHARED_NODE_PROPERTIES }, type: 'add', })), [] diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index b865b9d3a1..278be3c498 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -1,6 +1,4 @@ import { - Edge, - Node, OnConnectStartParams, SelectionMode, Viewport, @@ -8,16 +6,16 @@ import { } from 'reactflow'; import { FieldIdentifier, FieldType } from '../types/field'; import { - AnyNodeData, - InvocationEdgeExtra, + AnyNode, + InvocationNodeEdge, InvocationTemplate, NodeExecutionState, } from '../types/invocation'; import { WorkflowV2 } from '../types/workflow'; export type NodesState = { - nodes: Node[]; - edges: Edge[]; + nodes: AnyNode[]; + edges: InvocationNodeEdge[]; nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; @@ -37,8 +35,8 @@ export type NodesState = { isReady: boolean; mouseOverField: FieldIdentifier | null; mouseOverNode: string | null; - nodesToCopy: Node[]; - edgesToCopy: Edge[]; + nodesToCopy: AnyNode[]; + edgesToCopy: InvocationNodeEdge[]; isAddNodePopoverOpen: boolean; addNewNodePosition: XYPosition | null; selectionMode: SelectionMode; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts index 5328f789ad..c2582600c3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts @@ -1,4 +1,4 @@ -import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import { FieldInputInstance, FieldOutputInstance, @@ -14,10 +14,6 @@ import { reduce } from 'lodash-es'; import { Node, XYPosition } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -export const SHARED_NODE_PROPERTIES: Partial = { - dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, -}; - export const buildNotesNode = (position: XYPosition): Node => { const nodeId = uuidv4(); const node: Node = { diff --git a/invokeai/frontend/web/src/features/nodes/types/constants.ts b/invokeai/frontend/web/src/features/nodes/types/constants.ts index a97899de91..27cb4fa778 100644 --- a/invokeai/frontend/web/src/features/nodes/types/constants.ts +++ b/invokeai/frontend/web/src/features/nodes/types/constants.ts @@ -1,3 +1,5 @@ +import { Node } from 'reactflow'; + /** * How long to wait before showing a tooltip when hovering a field handle. */ @@ -14,6 +16,13 @@ export const NODE_WIDTH = 320; */ export const DRAG_HANDLE_CLASSNAME = 'node-drag-handle'; +/** + * reactflow-specifc properties shared between all node types. + */ +export const SHARED_NODE_PROPERTIES: Partial = { + dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, +}; + /** * Helper for getting the kind of a field. */ diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 70403169de..5d22e64545 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -1,4 +1,4 @@ -import { Node } from 'reactflow'; +import { Edge, Node } from 'reactflow'; import { z } from 'zod'; import { zProgressImage } from './common'; import { @@ -64,16 +64,16 @@ export type InvocationNodeData = z.infer; export type CurrentImageNodeData = z.infer; export type AnyNodeData = z.infer; -export const isInvocationNode = ( - node?: Node -): node is Node => +export type InvocationNode = Node; +export type NotesNode = Node; +export type CurrentImageNode = Node; +export type AnyNode = Node; + +export const isInvocationNode = (node?: AnyNode): node is InvocationNode => Boolean(node && node.type === 'invocation'); -export const isNotesNode = ( - node?: Node -): node is Node => Boolean(node && node.type === 'notes'); -export const isCurrentImageNode = ( - node?: Node -): node is Node => +export const isNotesNode = (node?: AnyNode): node is NotesNode => + Boolean(node && node.type === 'notes'); +export const isCurrentImageNode = (node?: AnyNode): node is CurrentImageNode => Boolean(node && node.type === 'current_image'); export const isInvocationNodeData = ( node?: AnyNodeData @@ -101,8 +101,9 @@ export type NodeStatus = z.infer; // #endregion // #region Edges -export const zInvocationEdgeExtra = z.object({ +export const zInvocationNodeEdgeExtra = z.object({ type: z.union([z.literal('default'), z.literal('collapsed')]), }); -export type InvocationEdgeExtra = z.infer; +export type InvocationNodeEdgeExtra = z.infer; +export type InvocationNodeEdge = Edge; // #endregion From 8d99113bef753fe7cbfdc9de044672ac24fd1323 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 27 Nov 2023 00:00:44 +1100 Subject: [PATCH 20/65] feat(ui): organize node utils --- .../listeners/enqueueRequestedCanvas.ts | 4 +- .../listeners/enqueueRequestedLinear.ts | 10 ++-- .../listeners/enqueueRequestedNodes.ts | 2 +- .../listeners/receivedOpenAPISchema.ts | 2 +- .../socketio/socketInvocationComplete.ts | 2 +- .../listeners/updateAllNodesRequested.ts | 2 +- .../listeners/upscaleRequested.ts | 2 +- .../listeners/workflowLoadRequested.ts | 2 +- .../flow/AddNodePopover/AddNodePopover.tsx | 4 +- .../inspector/InspectorDetailsTab.tsx | 2 +- .../hooks/useAnyOrDirectInputFieldNames.ts | 4 +- .../{useBuildNodeData.ts => useBuildNode.ts} | 11 ++--- .../hooks/useConnectionInputFieldNames.ts | 4 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 2 +- .../nodes/hooks/useNodeNeedsUpdate.ts | 2 +- .../nodes/hooks/useOutputFieldNames.ts | 2 +- .../src/features/nodes/hooks/useWorkflow.ts | 2 +- .../addControlNetToLinearGraph.ts | 0 .../{graphBuilders => graph}/addHrfToGraph.ts | 0 .../addIPAdapterToLinearGraph.ts | 0 .../addLinearUIOutputNode.ts | 0 .../addLoRAsToGraph.ts | 0 .../addNSFWCheckerToGraph.ts | 0 .../addSDXLLoRAstoGraph.ts | 0 .../addSDXLRefinerToGraph.ts | 0 .../addSeamlessToLinearGraph.ts | 0 .../addT2IAdapterToLinearGraph.ts | 0 .../{graphBuilders => graph}/addVAEToGraph.ts | 0 .../addWatermarkerToGraph.ts | 0 .../buildAdHocUpscaleGraph.ts | 0 .../buildCanvasGraph.ts | 0 .../buildCanvasImageToImageGraph.ts | 0 .../buildCanvasInpaintGraph.ts | 0 .../buildCanvasOutpaintGraph.ts | 0 .../buildCanvasSDXLImageToImageGraph.ts | 0 .../buildCanvasSDXLInpaintGraph.ts | 0 .../buildCanvasSDXLOutpaintGraph.ts | 0 .../buildCanvasSDXLTextToImageGraph.ts | 0 .../buildCanvasTextToImageGraph.ts | 0 .../buildLinearBatchConfig.ts | 0 .../buildLinearImageToImageGraph.ts | 0 .../buildLinearSDXLImageToImageGraph.ts | 0 .../buildLinearSDXLTextToImageGraph.ts | 0 .../buildLinearTextToImageGraph.ts | 0 .../buildNodesGraph.ts | 2 +- .../{graphBuilders => graph}/constants.ts | 0 .../helpers/craftSDXLStylePrompt.ts | 0 .../util/{graphBuilders => graph}/metadata.ts | 0 .../nodes/util/node/buildCurrentImageNode.ts | 23 +++++++++ .../node/buildInvocationNode.ts} | 49 ++----------------- .../nodes/util/node/buildNotesNode.ts | 22 +++++++++ .../{ => node}/getSortedFilteredFieldNames.ts | 2 +- .../{store/util => util/node}/nodeUpdate.ts | 2 +- .../{ => schema}/buildFieldInputInstance.ts | 2 +- .../{ => schema}/buildFieldInputTemplate.ts | 4 +- .../{ => schema}/buildFieldOutputTemplate.ts | 4 +- .../nodes/util/{ => schema}/parseFieldType.ts | 9 ++-- .../nodes/util/{ => schema}/parseSchema.ts | 11 +++-- .../util/{ => workflow}/buildWorkflow.ts | 4 +- .../util/{ => workflow}/validateWorkflow.ts | 8 +-- 60 files changed, 106 insertions(+), 95 deletions(-) rename invokeai/frontend/web/src/features/nodes/hooks/{useBuildNodeData.ts => useBuildNode.ts} (87%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addControlNetToLinearGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addHrfToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addIPAdapterToLinearGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addLinearUIOutputNode.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addLoRAsToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addNSFWCheckerToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addSDXLLoRAstoGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addSDXLRefinerToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addSeamlessToLinearGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addT2IAdapterToLinearGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addVAEToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/addWatermarkerToGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildAdHocUpscaleGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasImageToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasInpaintGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasOutpaintGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasSDXLImageToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasSDXLInpaintGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasSDXLOutpaintGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasSDXLTextToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildCanvasTextToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildLinearBatchConfig.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildLinearImageToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildLinearSDXLImageToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildLinearSDXLTextToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildLinearTextToImageGraph.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/buildNodesGraph.ts (98%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/constants.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/helpers/craftSDXLStylePrompt.ts (100%) rename invokeai/frontend/web/src/features/nodes/util/{graphBuilders => graph}/metadata.ts (100%) create mode 100644 invokeai/frontend/web/src/features/nodes/util/node/buildCurrentImageNode.ts rename invokeai/frontend/web/src/features/nodes/{store/util/buildNodeData.ts => util/node/buildInvocationNode.ts} (62%) create mode 100644 invokeai/frontend/web/src/features/nodes/util/node/buildNotesNode.ts rename invokeai/frontend/web/src/features/nodes/util/{ => node}/getSortedFilteredFieldNames.ts (90%) rename invokeai/frontend/web/src/features/nodes/{store/util => util/node}/nodeUpdate.ts (97%) rename invokeai/frontend/web/src/features/nodes/util/{ => schema}/buildFieldInputInstance.ts (93%) rename invokeai/frontend/web/src/features/nodes/util/{ => schema}/buildFieldInputTemplate.ts (99%) rename invokeai/frontend/web/src/features/nodes/util/{ => schema}/buildFieldOutputTemplate.ts (81%) rename invokeai/frontend/web/src/features/nodes/util/{ => schema}/parseFieldType.ts (97%) rename invokeai/frontend/web/src/features/nodes/util/{ => schema}/parseSchema.ts (96%) rename invokeai/frontend/web/src/features/nodes/util/{ => workflow}/buildWorkflow.ts (90%) rename invokeai/frontend/web/src/features/nodes/util/{ => workflow}/validateWorkflow.ts (94%) diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts index 8c283ce64e..bcaf778b6e 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedCanvas.ts @@ -10,8 +10,8 @@ import { blobToDataURL } from 'features/canvas/util/blobToDataURL'; import { getCanvasData } from 'features/canvas/util/getCanvasData'; import { getCanvasGenerationMode } from 'features/canvas/util/getCanvasGenerationMode'; import { canvasGraphBuilt } from 'features/nodes/store/actions'; -import { buildCanvasGraph } from 'features/nodes/util/graphBuilders/buildCanvasGraph'; -import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; +import { buildCanvasGraph } from 'features/nodes/util/graph/buildCanvasGraph'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { imagesApi } from 'services/api/endpoints/images'; import { queueApi } from 'services/api/endpoints/queue'; import { ImageDTO } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts index bb89d18b91..faeecfb44c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts @@ -1,9 +1,9 @@ import { enqueueRequested } from 'app/store/actions'; -import { prepareLinearUIBatch } from 'features/nodes/util/graphBuilders/buildLinearBatchConfig'; -import { buildLinearImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearImageToImageGraph'; -import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph'; -import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph'; -import { buildLinearTextToImageGraph } from 'features/nodes/util/graphBuilders/buildLinearTextToImageGraph'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; +import { buildLinearImageToImageGraph } from 'features/nodes/util/graph/buildLinearImageToImageGraph'; +import { buildLinearSDXLImageToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLImageToImageGraph'; +import { buildLinearSDXLTextToImageGraph } from 'features/nodes/util/graph/buildLinearSDXLTextToImageGraph'; +import { buildLinearTextToImageGraph } from 'features/nodes/util/graph/buildLinearTextToImageGraph'; import { queueApi } from 'services/api/endpoints/queue'; import { startAppListening } from '..'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index b87e443a4e..b9b1060f18 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -1,5 +1,5 @@ import { enqueueRequested } from 'app/store/actions'; -import { buildNodesGraph } from 'features/nodes/util/graphBuilders/buildNodesGraph'; +import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph'; import { queueApi } from 'services/api/endpoints/queue'; import { BatchConfig } from 'services/api/types'; import { startAppListening } from '..'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index f5b630a39d..ff44317fcf 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -1,7 +1,7 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; -import { parseSchema } from 'features/nodes/util/parseSchema'; +import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import { startAppListening } from '..'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index bc9959b8fc..364a2658bf 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -10,7 +10,7 @@ import { IMAGE_CATEGORIES } from 'features/gallery/store/types'; import { LINEAR_UI_OUTPUT, nodeIDDenyList, -} from 'features/nodes/util/graphBuilders/constants'; +} from 'features/nodes/util/graph/constants'; import { boardsApi } from 'services/api/endpoints/boards'; import { imagesApi } from 'services/api/endpoints/images'; import { imagesAdapter } from 'services/api/util'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index b2383410bd..1df083c795 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -4,7 +4,7 @@ import { nodeReplaced } from 'features/nodes/store/nodesSlice'; import { getNeedsUpdate, updateNode, -} from 'features/nodes/store/util/nodeUpdate'; +} from 'features/nodes/util/node/nodeUpdate'; import { NodeUpdateError } from 'features/nodes/types/error'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { addToast } from 'features/system/store/systemSlice'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts index 9ddcdc9701..7dbb7d9fb1 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/upscaleRequested.ts @@ -1,7 +1,7 @@ import { createAction } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { buildAdHocUpscaleGraph } from 'features/nodes/util/graphBuilders/buildAdHocUpscaleGraph'; +import { buildAdHocUpscaleGraph } from 'features/nodes/util/graph/buildAdHocUpscaleGraph'; import { addToast } from 'features/system/store/systemSlice'; import { t } from 'i18next'; import { queueApi } from 'services/api/endpoints/queue'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 5336c63942..20d0ac0773 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -4,7 +4,7 @@ import { workflowLoadRequested } from 'features/nodes/store/actions'; import { workflowLoaded } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; import { WorkflowVersionError } from 'features/nodes/types/error'; -import { validateWorkflow } from 'features/nodes/util/validateWorkflow'; +import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; import { addToast } from 'features/system/store/systemSlice'; import { makeToast } from 'features/system/util/makeToast'; import { setActiveTab } from 'features/ui/store/uiSlice'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 419899abdb..7d854484e0 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -11,7 +11,7 @@ import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import IAIMantineSearchableSelect from 'common/components/IAIMantineSearchableSelect'; -import { useBuildNodeData } from 'features/nodes/hooks/useBuildNodeData'; +import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { addNodePopoverClosed, addNodePopoverOpened, @@ -51,7 +51,7 @@ const selectFilter = (value: string, item: NodeTemplate) => { const AddNodePopover = () => { const dispatch = useAppDispatch(); - const buildInvocation = useBuildNodeData(); + const buildInvocation = useBuildNode(); const toaster = useAppToaster(); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index ecbe538fcc..fad2443c4f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -11,7 +11,7 @@ import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { getNeedsUpdate } from 'features/nodes/store/util/nodeUpdate'; +import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; import { InvocationNodeData, InvocationTemplate, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index ccfa0f57fd..b934c6b959 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -5,8 +5,8 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/invocation'; -import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; -import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; +import { getSortedFilteredFieldNames } from '../util/node/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/schema/buildFieldInputTemplate'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts similarity index 87% rename from invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts rename to invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index 46f11b3823..69c0757689 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -3,13 +3,12 @@ import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { Node, useReactFlow } from 'reactflow'; -import { - buildCurrentImageNode, - buildInvocationNode, - buildNotesNode, -} from '../store/util/buildNodeData'; import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../types/constants'; import { AnyNode, InvocationTemplate } from '../types/invocation'; +import { buildCurrentImageNode } from '../util/node/buildCurrentImageNode'; +import { buildInvocationNode } from '../util/node/buildInvocationNode'; +import { buildNotesNode } from '../util/node/buildNotesNode'; + const templatesSelector = createSelector( [(state: RootState) => state.nodes], (nodes) => nodes.nodeTemplates @@ -19,7 +18,7 @@ export const SHARED_NODE_PROPERTIES: Partial = { dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, }; -export const useBuildNodeData = () => { +export const useBuildNode = () => { const nodeTemplates = useAppSelector(templatesSelector); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index 2951167944..a2694bd1ba 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -5,8 +5,8 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { keys, map } from 'lodash-es'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/invocation'; -import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; -import { TEMPLATE_BUILDER_MAP } from '../util/buildFieldInputTemplate'; +import { getSortedFilteredFieldNames } from '../util/node/getSortedFilteredFieldNames'; +import { TEMPLATE_BUILDER_MAP } from '../util/schema/buildFieldInputTemplate'; export const useConnectionInputFieldNames = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index c22c0d9505..9673c6417f 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; -import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { getNeedsUpdate } from '../util/node/nodeUpdate'; import { isInvocationNode } from '../types/invocation'; const selector = createSelector( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index 99a7c47170..cf3ecfbc12 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -4,7 +4,7 @@ import { useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/invocation'; -import { getNeedsUpdate } from '../store/util/nodeUpdate'; +import { getNeedsUpdate } from '../util/node/nodeUpdate'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index 93e4ccb833..ec0315b227 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -5,7 +5,7 @@ import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { map } from 'lodash-es'; import { useMemo } from 'react'; import { isInvocationNode } from '../types/invocation'; -import { getSortedFilteredFieldNames } from '../util/getSortedFilteredFieldNames'; +import { getSortedFilteredFieldNames } from '../util/node/getSortedFilteredFieldNames'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflow.ts b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflow.ts index f729aa1004..b0799630c9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useWorkflow.ts @@ -1,6 +1,6 @@ import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { buildWorkflow } from 'features/nodes/util/buildWorkflow'; +import { buildWorkflow } from 'features/nodes/util/workflow/buildWorkflow'; import { useMemo } from 'react'; import { useDebounce } from 'use-debounce'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addControlNetToLinearGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addControlNetToLinearGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addHrfToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addHrfToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addIPAdapterToLinearGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addIPAdapterToLinearGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLinearUIOutputNode.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addLinearUIOutputNode.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addLoRAsToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addLoRAsToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addNSFWCheckerToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addNSFWCheckerToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLLoRAstoGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addSDXLLoRAstoGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSDXLRefinerToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addSDXLRefinerToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addSeamlessToLinearGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addSeamlessToLinearGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addT2IAdapterToLinearGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addT2IAdapterToLinearGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addVAEToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addVAEToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/addWatermarkerToGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/addWatermarkerToGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildAdHocUpscaleGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildAdHocUpscaleGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasImageToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasImageToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasInpaintGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasInpaintGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasOutpaintGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasOutpaintGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLImageToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLImageToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLInpaintGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLInpaintGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLOutpaintGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLOutpaintGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasSDXLTextToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasSDXLTextToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildCanvasTextToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildCanvasTextToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearBatchConfig.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildLinearBatchConfig.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearImageToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildLinearImageToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLImageToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLImageToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearSDXLTextToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildLinearSDXLTextToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts similarity index 98% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts index 9ed0eb1d32..cfa306b2c6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts @@ -4,7 +4,7 @@ import { cloneDeep, omit, reduce } from 'lodash-es'; import { Graph } from 'services/api/types'; import { AnyInvocation } from 'services/events/types'; import { v4 as uuidv4 } from 'uuid'; -import { buildWorkflow } from '../buildWorkflow'; +import { buildWorkflow } from '../workflow/buildWorkflow'; import { FieldInputInstance, isColorFieldInputInstance, diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts b/invokeai/frontend/web/src/features/nodes/util/graph/constants.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/constants.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/constants.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/helpers/craftSDXLStylePrompt.ts b/invokeai/frontend/web/src/features/nodes/util/graph/helpers/craftSDXLStylePrompt.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/helpers/craftSDXLStylePrompt.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/helpers/craftSDXLStylePrompt.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts b/invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/util/graphBuilders/metadata.ts rename to invokeai/frontend/web/src/features/nodes/util/graph/metadata.ts diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildCurrentImageNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildCurrentImageNode.ts new file mode 100644 index 0000000000..8440cbb3d1 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildCurrentImageNode.ts @@ -0,0 +1,23 @@ +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; +import { CurrentImageNode } from 'features/nodes/types/invocation'; +import { XYPosition } from 'reactflow'; +import { v4 as uuidv4 } from 'uuid'; + +export const buildCurrentImageNode = ( + position: XYPosition +): CurrentImageNode => { + const nodeId = uuidv4(); + const node: CurrentImageNode = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'current_image', + position, + data: { + id: nodeId, + type: 'current_image', + isOpen: true, + label: 'Current Image', + }, + }; + return node; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts similarity index 62% rename from invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts rename to invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts index c2582600c3..1da3b23a9a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/buildNodeData.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts @@ -4,57 +4,18 @@ import { FieldOutputInstance, } from 'features/nodes/types/field'; import { - CurrentImageNodeData, - InvocationNodeData, + InvocationNode, InvocationTemplate, - NotesNodeData, } from 'features/nodes/types/invocation'; -import { buildFieldInputInstance } from 'features/nodes/util/buildFieldInputInstance'; +import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; import { reduce } from 'lodash-es'; -import { Node, XYPosition } from 'reactflow'; +import { XYPosition } from 'reactflow'; import { v4 as uuidv4 } from 'uuid'; -export const buildNotesNode = (position: XYPosition): Node => { - const nodeId = uuidv4(); - const node: Node = { - ...SHARED_NODE_PROPERTIES, - id: nodeId, - type: 'notes', - position, - data: { - id: nodeId, - isOpen: true, - label: 'Notes', - notes: '', - type: 'notes', - }, - }; - return node; -}; - -export const buildCurrentImageNode = ( - position: XYPosition -): Node => { - const nodeId = uuidv4(); - const node: Node = { - ...SHARED_NODE_PROPERTIES, - id: nodeId, - type: 'current_image', - position, - data: { - id: nodeId, - type: 'current_image', - isOpen: true, - label: 'Current Image', - }, - }; - return node; -}; - export const buildInvocationNode = ( position: XYPosition, template: InvocationTemplate -): Node => { +): InvocationNode => { const nodeId = uuidv4(); const { type } = template; @@ -94,7 +55,7 @@ export const buildInvocationNode = ( {} as Record ); - const node: Node = { + const node: InvocationNode = { ...SHARED_NODE_PROPERTIES, id: nodeId, type: 'invocation', diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildNotesNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildNotesNode.ts new file mode 100644 index 0000000000..268d23d2db --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildNotesNode.ts @@ -0,0 +1,22 @@ +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; +import { NotesNode } from 'features/nodes/types/invocation'; +import { XYPosition } from 'reactflow'; +import { v4 as uuidv4 } from 'uuid'; + +export const buildNotesNode = (position: XYPosition): NotesNode => { + const nodeId = uuidv4(); + const node: NotesNode = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'notes', + position, + data: { + id: nodeId, + isOpen: true, + label: 'Notes', + notes: '', + type: 'notes', + }, + }; + return node; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts b/invokeai/frontend/web/src/features/nodes/util/node/getSortedFilteredFieldNames.ts similarity index 90% rename from invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts rename to invokeai/frontend/web/src/features/nodes/util/node/getSortedFilteredFieldNames.ts index 2ed5faca29..2aa1ccc172 100644 --- a/invokeai/frontend/web/src/features/nodes/util/getSortedFilteredFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/getSortedFilteredFieldNames.ts @@ -1,5 +1,5 @@ import { isNil } from 'lodash-es'; -import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; +import { FieldInputTemplate, FieldOutputTemplate } from '../../types/field'; export const getSortedFilteredFieldNames = ( fields: FieldInputTemplate[] | FieldOutputTemplate[] diff --git a/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts similarity index 97% rename from invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts rename to invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts index e9e24823f9..d1913f6e05 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/nodeUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/nodeUpdate.ts @@ -7,7 +7,7 @@ import { import { zParsedSemver } from 'features/nodes/types/semver'; import { cloneDeep, defaultsDeep } from 'lodash-es'; import { Node } from 'reactflow'; -import { buildInvocationNode } from './buildNodeData'; +import { buildInvocationNode } from './buildInvocationNode'; export const getNeedsUpdate = ( node: Node, diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts similarity index 93% rename from invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts rename to invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts index 200bd98e86..2e06652d4b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputInstance.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts @@ -1,5 +1,5 @@ import { get } from 'lodash-es'; -import { FieldInputInstance, FieldInputTemplate } from '../types/field'; +import { FieldInputInstance, FieldInputTemplate } from '../../types/field'; const FIELD_VALUE_FALLBACK_MAP = { EnumField: '', diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts similarity index 99% rename from invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts rename to invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts index 0deddf0dea..9f33c1328f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts @@ -22,8 +22,8 @@ import { T2IAdapterModelFieldInputTemplate, VAEModelFieldInputTemplate, isStatefulFieldType, -} from '../types/field'; -import { InvocationFieldSchema } from '../types/openapi'; +} from '../../types/field'; +import { InvocationFieldSchema } from '../../types/openapi'; // eslint-disable-next-line @typescript-eslint/no-explicit-any type FieldInputTemplateBuilder = // valid `any`! diff --git a/invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts similarity index 81% rename from invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts rename to invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts index 05e3c66386..0d363da429 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/buildFieldOutputTemplate.ts @@ -1,6 +1,6 @@ import { startCase } from 'lodash-es'; -import { FieldOutputTemplate, FieldType } from '../types/field'; -import { InvocationFieldSchema } from '../types/openapi'; +import { FieldOutputTemplate, FieldType } from '../../types/field'; +import { InvocationFieldSchema } from '../../types/openapi'; export const buildFieldOutputTemplate = ( fieldSchema: InvocationFieldSchema, diff --git a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts similarity index 97% rename from invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts rename to invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts index 2d25ab9faa..2314faaa39 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseFieldType.ts @@ -1,8 +1,11 @@ import { t } from 'i18next'; import { isArray } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; -import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; -import { FieldType } from '../types/field'; +import { + FieldTypeParseError, + UnsupportedFieldTypeError, +} from '../../types/error'; +import { FieldType } from '../../types/field'; import { OpenAPIV3_1SchemaOrRef, isArraySchemaObject, @@ -10,7 +13,7 @@ import { isNonArraySchemaObject, isRefObject, isSchemaObject, -} from '../types/openapi'; +} from '../../types/openapi'; /** * Transforms an invocation output ref object to field type. diff --git a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts similarity index 96% rename from invokeai/frontend/web/src/features/nodes/util/parseSchema.ts rename to invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 81d79d2976..04b3f66e7d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -3,15 +3,18 @@ import { parseify } from 'common/util/serialize'; import { t } from 'i18next'; import { reduce } from 'lodash-es'; import { OpenAPIV3_1 } from 'openapi-types'; -import { FieldTypeParseError, UnsupportedFieldTypeError } from '../types/error'; -import { FieldInputTemplate, FieldOutputTemplate } from '../types/field'; -import { InvocationTemplate } from '../types/invocation'; +import { + FieldTypeParseError, + UnsupportedFieldTypeError, +} from '../../types/error'; +import { FieldInputTemplate, FieldOutputTemplate } from '../../types/field'; +import { InvocationTemplate } from '../../types/invocation'; import { InvocationSchemaObject, isInvocationFieldSchema, isInvocationOutputSchemaObject, isInvocationSchemaObject, -} from '../types/openapi'; +} from '../../types/openapi'; import { buildFieldInputTemplate } from './buildFieldInputTemplate'; import { buildFieldOutputTemplate } from './buildFieldOutputTemplate'; import { parseFieldType } from './parseFieldType'; diff --git a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts similarity index 90% rename from invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts rename to invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 7e49be4068..ee28376347 100644 --- a/invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; -import { NodesState } from '../store/types'; -import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../types/workflow'; +import { NodesState } from '../../store/types'; +import { WorkflowV2, zWorkflowEdge, zWorkflowNode } from '../../types/workflow'; import { fromZodError } from 'zod-validation-error'; import { parseify } from 'common/util/serialize'; import i18n from 'i18next'; diff --git a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts similarity index 94% rename from invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts rename to invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 6d2ee13cf2..08ff0c4daf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -2,10 +2,10 @@ import { parseify } from 'common/util/serialize'; import { t } from 'i18next'; import { keyBy } from 'lodash-es'; import { JsonObject } from 'type-fest'; -import { getNeedsUpdate } from '../store/util/nodeUpdate'; -import { InvocationTemplate } from '../types/invocation'; -import { parseAndMigrateWorkflow } from '../types/migration/migrations'; -import { WorkflowV2, isWorkflowInvocationNode } from '../types/workflow'; +import { getNeedsUpdate } from '../node/nodeUpdate'; +import { InvocationTemplate } from '../../types/invocation'; +import { parseAndMigrateWorkflow } from '../../types/migration/migrations'; +import { WorkflowV2, isWorkflowInvocationNode } from '../../types/workflow'; type WorkflowWarning = { message: string; From 0d9a546d7479efcb6787062fe4e126c44b99e937 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 27 Nov 2023 01:07:21 +1100 Subject: [PATCH 21/65] feat(ui): organize migrations files --- .../nodes/types/{migration => }/v1/fieldTypeMap.ts | 2 +- .../nodes/types/{migration => }/v1/workflowV1.ts | 0 .../{types/migration => util/workflow}/migrations.ts | 12 ++++++------ .../features/nodes/util/workflow/validateWorkflow.ts | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) rename invokeai/frontend/web/src/features/nodes/types/{migration => }/v1/fieldTypeMap.ts (99%) rename invokeai/frontend/web/src/features/nodes/types/{migration => }/v1/workflowV1.ts (100%) rename invokeai/frontend/web/src/features/nodes/{types/migration => util/workflow}/migrations.ts (84%) diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts similarity index 99% rename from invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts rename to invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts index facf015b02..aace29d523 100644 --- a/invokeai/frontend/web/src/features/nodes/types/migration/v1/fieldTypeMap.ts +++ b/invokeai/frontend/web/src/features/nodes/types/v1/fieldTypeMap.ts @@ -1,4 +1,4 @@ -import { FieldType, StatefulFieldType } from '../../field'; +import { FieldType, StatefulFieldType } from '../field'; import { FieldTypeV1 } from './workflowV1'; /** diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts b/invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts similarity index 100% rename from invokeai/frontend/web/src/features/nodes/types/migration/v1/workflowV1.ts rename to invokeai/frontend/web/src/features/nodes/types/v1/workflowV1.ts diff --git a/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts similarity index 84% rename from invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts rename to invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 45c3852493..5428ac9861 100644 --- a/invokeai/frontend/web/src/features/nodes/types/migration/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -1,11 +1,11 @@ +import { t } from 'i18next'; import { forEach, isString } from 'lodash-es'; import { z } from 'zod'; -import { WorkflowVersionError } from '../error'; -import { zSemVer } from '../semver'; -import { WorkflowV2, zWorkflowV2 } from '../workflow'; -import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from './v1/fieldTypeMap'; -import { WorkflowV1, zWorkflowV1 } from './v1/workflowV1'; -import { t } from 'i18next'; +import { WorkflowVersionError } from '../../types/error'; +import { zSemVer } from '../../types/semver'; +import { FIELD_TYPE_V1_TO_FIELD_TYPE_V2_MAPPING } from '../../types/v1/fieldTypeMap'; +import { WorkflowV1, zWorkflowV1 } from '../../types/v1/workflowV1'; +import { WorkflowV2, zWorkflowV2 } from '../../types/workflow'; /** * Helper schema to extract the version from a workflow. diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index 08ff0c4daf..07b3cc5bfc 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -4,7 +4,7 @@ import { keyBy } from 'lodash-es'; import { JsonObject } from 'type-fest'; import { getNeedsUpdate } from '../node/nodeUpdate'; import { InvocationTemplate } from '../../types/invocation'; -import { parseAndMigrateWorkflow } from '../../types/migration/migrations'; +import { parseAndMigrateWorkflow } from './migrations'; import { WorkflowV2, isWorkflowInvocationNode } from '../../types/workflow'; type WorkflowWarning = { From a02090b06ba120a50a4ddfc1066ab139666419c8 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 27 Nov 2023 01:30:09 +1100 Subject: [PATCH 22/65] feat(ui): update workflows design & implementation docs --- .../docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md | 112 +++++++++++++++--- 1 file changed, 95 insertions(+), 17 deletions(-) diff --git a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md index 70013499d0..757204069c 100644 --- a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md +++ b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md @@ -9,7 +9,7 @@ - [Linear UI](#linear-ui) - [Workflow Editor](#workflow-editor) - [Workflows](#workflows) - - [Workflow -\> reactflow state -\> InvokeAI graph](#workflow---reactflow-state---invokeai-graph) + - [Workflow -> reactflow state -> InvokeAI graph](#workflow---reactflow-state---invokeai-graph) - [Nodes vs Invocations](#nodes-vs-invocations) - [Workflow Linear View](#workflow-linear-view) - [OpenAPI Schema](#openapi-schema) @@ -27,10 +27,16 @@ - [Optional Fields](#optional-fields) - [Building Field Input Templates](#building-field-input-templates) - [Building Field Output Templates](#building-field-output-templates) + - [Managing reactflow State](#managing-reactflow-state) + - [Building Nodes and Edges](#building-nodes-and-edges) + - [Building a Workflow](#building-a-workflow) + - [Loading a Workflow](#loading-a-workflow) - [Workflow Migrations](#workflow-migrations) +> This document describes, at a high level, the design and implementation of workflows in the InvokeAI frontend. There are a substantial number of implementation details not included, but which are hopefully clear from the code. + InvokeAI's backend uses graphs, composed of **nodes** and **edges**, to process data and generate images. Nodes have any number of **input fields** and **output fields**. Edges connect nodes together via their inputs and outputs. Fields have data types which dictate how they may be connected. @@ -54,7 +60,7 @@ The user-managed parameters on these tabs are stored as simple objects in the ap This logic can be fairly complex due to the range of features available and their interactions. Depending on the parameters selected, the graph may be very different. Building graphs in code can be challenging - you are trying to construct a non-linear structure in a linear context. The simplest graph building logic is for **Text to Image** with a SD1.5 model: -`invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildLinearTextToImageGraph.ts` +`invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts` There are many other graph builders in the same folder for different tabs or base models (e.g. SDXL). Some are pretty hairy. @@ -64,7 +70,7 @@ In the Linear UI, we go straight from **simple application state** to **graph** The Workflow Editor is a visual graph editor, allowing users to draw edges from node to node to construct a graph. This _far_ more approachable way to create complex graphs. -InvokeAI uses the [reactflow](https://github.com/xyflow/xyflow) library to power the Workflow Editor. It provides both a graph editor UI and manages its own internal graph state. +InvokeAI uses the [reactflow] library to power the Workflow Editor. It provides both a graph editor UI and manages its own internal graph state. #### Workflows @@ -83,7 +89,7 @@ Workflows should have other qualities: - Resilient: you should be able to "upgrade" a workflow as the application changes. - Abstract: as much as is possible, workflows should not be married to the specific implementation details of the application. -To support these qualities, workflows are serializable, have a versioned schemas, and represent graphs as minimally as possible. Fortunately, the reactflow state for nodes and edges works perfectly for this.. +To support these qualities, workflows are serializable, have a versioned schemas, and represent graphs as minimally as possible. Fortunately, the reactflow state for nodes and edges works perfectly for this. ##### Workflow -> reactflow state -> InvokeAI graph @@ -92,13 +98,13 @@ Given a workflow, we need to be able to derive reactflow state and/or an InvokeA The first step - workflow to reactflow state - is very simple. The logic is in `invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts`, in the `workflowLoaded` reducer. The reactflow state is, however, structurally incompatible with our backend's graph structure. When a user invokes on a Workflow, we need to convert the reactflow state into an InvokeAI graph. This is far simpler than the graph building logic from the Linear UI: -`invokeai/frontend/web/src/features/nodes/util/graphBuilders/buildNodesGraph.ts` +`invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts` ##### Nodes vs Invocations We often use the terms "node" and "invocation" interchangeably, but they may refer to different things in the frontend. -reactflow [has its own definitions](https://reactflow.dev/learn/concepts/terms-and-definitions) of "node", "edge" and "handle" which are closely related to InvokeAI graph concepts. +reactflow [has its own definitions][reactflow-concepts] of "node", "edge" and "handle" which are closely related to InvokeAI graph concepts. - A reactflow node is related to an InvokeAI invocation. It has a "data" property, which holds the InvokeAI-specific invocation data. - A reactflow edge is roughly equivalent to an InvokeAI edge. @@ -157,13 +163,13 @@ If it is annotated as a union of a type and list, the type will be flagged as a ## Implementation -The majority of data structures in the backend are [pydantic](https://github.com/pydantic/pydantic) models. Pydantic provides OpenAPI schemas for all models and we then generate TypeScript types from those. +The majority of data structures in the backend are [pydantic] models. Pydantic provides OpenAPI schemas for all models and we then generate TypeScript types from those. The OpenAPI schema is parsed at runtime into our invocation templates. -Workflows and all related data are modeled in the frontend using [zod](https://github.com/colinhacks/zod). Related types are inferred from the zod schemas. +Workflows and all related data are modeled in the frontend using [zod]. Related types are inferred from the zod schemas. -> In python, invocations are pydantic models with fields. These fields become inputs. The invocation's `invoke()` function returns a pydantic model - its output. Like the invocation itself, the output model has any number of fields, which become outputs. +> In python, invocations are pydantic models with fields. These fields become node inputs. The invocation's `invoke()` function returns a pydantic model - its output. Like the invocation itself, the output model has any number of fields, which become node outputs. ### zod Schemas and Types @@ -176,11 +182,11 @@ Roughly order from lowest-level to highest: - `invocation.ts`: invocations and other node types - `workflow.ts`: workflows and constituents -We customize the OpenAPI schema to include additional properties on invocation and field schemas. To facilitate parsing this schema into templates, we modify/wrap the types from [openapi-types](https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types) in `openapi.ts`. +We customize the OpenAPI schema to include additional properties on invocation and field schemas. To facilitate parsing this schema into templates, we modify/wrap the types from [openapi-types] in `openapi.ts`. ### OpenAPI Schema Parsing -The entrypoint for the OpenAPI schema parsing is `invokeai/frontend/web/src/features/nodes/util/parseSchema.ts`. +Schema parsing logic lives in `invokeai/frontend/web/src/features/nodes/util/schema/`. The entrypoint is `parseSchema.ts`. General logic flow: @@ -208,7 +214,7 @@ type FieldType = { }; ``` -The parsing logic is in `invokeai/frontend/web/src/features/nodes/util/parseFieldType.ts`. +The parsing logic is in `parseFieldType.ts`. There are 4 general cases for field type parsing. @@ -222,7 +228,9 @@ We create a field type name from this `type` string (e.g. `string` -> `StringFie When a field is annotated as a pydantic model (e.g. `ImageField`, `MainModelField`, `ControlField`), it is represented as a **reference object**. Reference objects are pointers to another schema or reference object within the schema. -We need to **dereference**[^dereference] the schema to pull these out. Dereferencing may require recursion. We use the reference object's name directly for the field type name. +We need to **dereference** the schema to pull these out. Dereferencing may require recursion. We use the reference object's name directly for the field type name. + +> Unfortunately, at this time, we've had limited success using external libraries to deference at runtime, so we do this ourselves. ##### Collection Types @@ -244,11 +252,9 @@ Handling this adds a fair bit of complexity, as we now must filter out the `'nul If there is a single remaining schema object, we must recursively call to `parseFieldType()` to get parse it. -[^dereference]: Unfortunately, at this time, we've had limited success using external libraries to deference at runtime, so we do this ourselves. - #### Building Field Input Templates -Now that we have a field type, we can build an input template for the field. This logic is in `invokeai/frontend/web/src/features/nodes/util/buildFieldInputTemplate.ts`. +Now that we have a field type, we can build an input template for the field. This logic is in `buildFieldInputTemplate.ts`. Stateful fields all get a function to build their template, while stateless fields are constructed directly. This is possible because stateless fields have no default value or constraints. @@ -256,6 +262,78 @@ Stateful fields all get a function to build their template, while stateless fiel Field outputs are similar to stateless fields - they do not have any value in the frontend. When building their templates, we don't need a special function for each field type. -The logic is in `invokeai/frontend/web/src/features/nodes/util/buildFieldOutputTemplate.ts`. +The logic is in `buildFieldOutputTemplate.ts`. + +### Managing reactflow State + +As described above, the workflow editor state is the essentially the reactflow state, plus some extra metadata. + +We provide reactflow with an array of nodes and edges via redux, and a number of [event handlers][reactflow-events]. These handlers dispatch redux actions, managing nodes and edges. + +The pieces of redux state relevant to workflows are: + +- `state.nodes.nodes`: the reactflow nodes state +- `state.nodes.edges`: the reactflow edges state +- `state.nodes.workflow`: the workflow metadata + +#### Building Nodes and Edges + +A reactflow node has a few important top-level properties: + +- `id`: unique identifier +- `type`: a string that maps to a react component to render the node +- `position`: XY coordinates +- `data`: arbitrary data + +When the user adds a node, we build **invocation node data**, storing it in `data`. Invocation properties (e.g. type, version, label, etc.) are copied from the invocation template. Inputs and outputs are built from the invocation template's field templates. + +See `invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts`. + +Edges are managed by reactflow, but briefly, they consist of: + +- `source`: id of the source node +- `sourceHandle`: id of the source node handle (output field) +- `target`: id of the target node +- `targetHandle`: id of the target node handle (input field) + +> Edge creation is gated behind validation logic. This validation compares the input and output field types and overall graph state. + +#### Building a Workflow + +Building a workflow entity is as simple as dropping the nodes, edges and metadata into an object. + +Each node and edge is parsed with a zod schema, which serves to strip out any unneeded data. + +See `invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts``. + +#### Loading a Workflow + +Workflows may be loaded from external sources or the user's local instance. In all cases, the workflow needs to be handled with care, as an untrusted object. + +Loading has a few stages which may throw or warn if there are problems: + +- Parsing the workflow data structure itself, [migrating](#workflow-migrations) it if necessary (throws) +- Check for a template for each node (warns) +- Check each node's version against its template (warns) +- Validate the source and target of each edge (warns) + +This validation occurs in `invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts`. + +If there are no fatal errors, the workflow is then stored in redux state. ### Workflow Migrations + +When the workflow schema changes, we may need to perform some data migrations. This occurs as workflows are loaded. zod schemas for each workflow schema version is retained to facilitate migrations. + +Previous schemas are in folders in `invokeai/frontend/web/src/features/nodes/types/`, eg `v1/`. + +Migration logic is in `invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts`. + + + +[pydantic]: https://github.com/pydantic/pydantic 'pydantic' +[zod]: https://github.com/colinhacks/zod 'zod' +[openapi-types]: https://github.com/kogosoftwarellc/open-api/tree/main/packages/openapi-types 'openapi-types' +[reactflow]: https://github.com/xyflow/xyflow 'reactflow' +[reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions +[reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers From e41d0b9a7681856bc4d8a69532dce55e9dd20901 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 27 Nov 2023 08:48:01 +1100 Subject: [PATCH 23/65] feat(ui): add links to relevant files in workflows doc --- .../docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md | 37 ++++++++++++------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md index 757204069c..d5cfc9154e 100644 --- a/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md +++ b/invokeai/frontend/web/docs/WORKFLOWS_DESIGN_IMPLEMENTATION.md @@ -59,10 +59,9 @@ The user-managed parameters on these tabs are stored as simple objects in the ap This logic can be fairly complex due to the range of features available and their interactions. Depending on the parameters selected, the graph may be very different. Building graphs in code can be challenging - you are trying to construct a non-linear structure in a linear context. -The simplest graph building logic is for **Text to Image** with a SD1.5 model: -`invokeai/frontend/web/src/features/nodes/util/graph/buildLinearTextToImageGraph.ts` +The simplest graph building logic is for **Text to Image** with a SD1.5 model: [buildLinearTextToImageGraph.ts] -There are many other graph builders in the same folder for different tabs or base models (e.g. SDXL). Some are pretty hairy. +There are many other graph builders in the same directory for different tabs or base models (e.g. SDXL). Some are pretty hairy. In the Linear UI, we go straight from **simple application state** to **graph** via these builders. @@ -95,10 +94,10 @@ To support these qualities, workflows are serializable, have a versioned schemas Given a workflow, we need to be able to derive reactflow state and/or an InvokeAI graph from it. -The first step - workflow to reactflow state - is very simple. The logic is in `invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts`, in the `workflowLoaded` reducer. +The first step - workflow to reactflow state - is very simple. The logic is in [nodesSlice.ts], in the `workflowLoaded` reducer. The reactflow state is, however, structurally incompatible with our backend's graph structure. When a user invokes on a Workflow, we need to convert the reactflow state into an InvokeAI graph. This is far simpler than the graph building logic from the Linear UI: -`invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts` +[buildNodesGraph.ts] ##### Nodes vs Invocations @@ -173,7 +172,7 @@ Workflows and all related data are modeled in the frontend using [zod]. Related ### zod Schemas and Types -The zod schemas, inferred types, and type guards are in `invokeai/frontend/web/src/features/nodes/types/`. +The zod schemas, inferred types, and type guards are in [types/]. Roughly order from lowest-level to highest: @@ -186,7 +185,7 @@ We customize the OpenAPI schema to include additional properties on invocation a ### OpenAPI Schema Parsing -Schema parsing logic lives in `invokeai/frontend/web/src/features/nodes/util/schema/`. The entrypoint is `parseSchema.ts`. +The entrypoint for OpenAPI schema parsing is [parseSchema.ts]. General logic flow: @@ -254,15 +253,17 @@ If there is a single remaining schema object, we must recursively call to `parse #### Building Field Input Templates -Now that we have a field type, we can build an input template for the field. This logic is in `buildFieldInputTemplate.ts`. +Now that we have a field type, we can build an input template for the field. Stateful fields all get a function to build their template, while stateless fields are constructed directly. This is possible because stateless fields have no default value or constraints. +See [buildFieldInputTemplate.ts]. + #### Building Field Output Templates Field outputs are similar to stateless fields - they do not have any value in the frontend. When building their templates, we don't need a special function for each field type. -The logic is in `buildFieldOutputTemplate.ts`. +See [buildFieldOutputTemplate.ts]. ### Managing reactflow State @@ -287,7 +288,7 @@ A reactflow node has a few important top-level properties: When the user adds a node, we build **invocation node data**, storing it in `data`. Invocation properties (e.g. type, version, label, etc.) are copied from the invocation template. Inputs and outputs are built from the invocation template's field templates. -See `invokeai/frontend/web/src/features/nodes/util/node/buildInvocationNode.ts`. +See [buildInvocationNode.ts]. Edges are managed by reactflow, but briefly, they consist of: @@ -304,7 +305,7 @@ Building a workflow entity is as simple as dropping the nodes, edges and metadat Each node and edge is parsed with a zod schema, which serves to strip out any unneeded data. -See `invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts``. +See [buildWorkflow.ts]. #### Loading a Workflow @@ -317,7 +318,7 @@ Loading has a few stages which may throw or warn if there are problems: - Check each node's version against its template (warns) - Validate the source and target of each edge (warns) -This validation occurs in `invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts`. +This validation occurs in [validateWorkflow.ts]. If there are no fatal errors, the workflow is then stored in redux state. @@ -327,7 +328,7 @@ When the workflow schema changes, we may need to perform some data migrations. T Previous schemas are in folders in `invokeai/frontend/web/src/features/nodes/types/`, eg `v1/`. -Migration logic is in `invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts`. +Migration logic is in [migrations.ts]. @@ -337,3 +338,13 @@ Migration logic is in `invokeai/frontend/web/src/features/nodes/util/workflow/mi [reactflow]: https://github.com/xyflow/xyflow 'reactflow' [reactflow-concepts]: https://reactflow.dev/learn/concepts/terms-and-definitions [reactflow-events]: https://reactflow.dev/api-reference/react-flow#event-handlers +[buildWorkflow.ts]: ../src/features/nodes/util/workflow/buildWorkflow.ts +[nodesSlice.ts]: ../src/features/nodes/store/nodesSlice.ts +[buildLinearTextToImageGraph.ts]: ../src/features/nodes/util/graph/buildLinearTextToImageGraph.ts +[buildNodesGraph.ts]: ../src/features/nodes/util/graph/buildNodesGraph.ts +[buildInvocationNode.ts]: ../src/features/nodes/util/node/buildInvocationNode.ts +[validateWorkflow.ts]: ../src/features/nodes/util/workflow/validateWorkflow.ts +[migrations.ts]: ../src/features/nodes/util/workflow/migrations.ts +[parseSchema.ts]: ../src/features/nodes/util/schema/parseSchema.ts +[buildFieldInputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldInputTemplate.ts +[buildFieldOutputTemplate.ts]: ../src/features/nodes/util/schema/buildFieldOutputTemplate.ts From 4c6a88a64205a827204a93fca2c73cb3452270f1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 27 Nov 2023 08:53:52 +1100 Subject: [PATCH 24/65] feat(ui): update readme --- invokeai/frontend/web/docs/README.md | 44 ++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/invokeai/frontend/web/docs/README.md b/invokeai/frontend/web/docs/README.md index 5f9e3c2c55..2545206c6a 100644 --- a/invokeai/frontend/web/docs/README.md +++ b/invokeai/frontend/web/docs/README.md @@ -14,6 +14,7 @@ - [i18next & Weblate](#i18next--weblate) - [openapi-typescript](#openapi-typescript) - [reactflow](#reactflow) + - [zod](#zod) - [Client Types Generation](#client-types-generation) - [Package Scripts](#package-scripts) - [Contributing](#contributing) @@ -31,50 +32,50 @@ InvokeAI's UI is made possible by a number of excellent open-source libraries. T ### Redux Toolkit -[Redux Toolkit](https://github.com/reduxjs/redux-toolkit) is used for state management and fetching/caching: +[Redux Toolkit] is used for state management and fetching/caching: - `RTK-Query` for data fetching and caching - `createAsyncThunk` for a couple other HTTP requests - `createEntityAdapter` to normalize things like images and models - `createListenerMiddleware` for async workflows -We use [redux-remember](https://github.com/zewish/redux-remember) for persistence. +We use [redux-remember] for persistence. ### Socket\.IO -[Socket\.IO](https://github.com/socketio/socket.io) is used for server-to-client events, like generation process and queue state changes. +[Socket.IO] is used for server-to-client events, like generation process and queue state changes. ### Chakra UI -[Chakra UI](https://github.com/chakra-ui/chakra-ui) is our primary UI library, but we also use a few components from [Mantine v6](https://v6.mantine.dev/). +[Chakra UI] is our primary UI library, but we also use a few components from [Mantine v6]. ### KonvaJS -[KonvaJS](https://github.com/konvajs/react-konva) powers the canvas. In the future, we'd like to explore [PixiJS](https://github.com/pixijs/pixijs) or WebGPU. +[KonvaJS] powers the canvas. In the future, we'd like to explore [PixiJS] or WebGPU. ### Vite -[Vite](https://github.com/vitejs/vite) is our bundler. +[Vite] is our bundler. ### i18next & Weblate -We use [i18next](https://github.com/i18next/react-i18next) for localization, but translation to languages other than English happens on our [Weblate](https://hosted.weblate.org/engage/invokeai/) project. **Only the English source strings should be changed on this repo.** +We use [i18next] for localization, but translation to languages other than English happens on our [Weblate] project. **Only the English source strings should be changed on this repo.** ### openapi-typescript -[openapi-typescript](https://github.com/drwpow/openapi-typescript) is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md. +[openapi-typescript] is used to generate types from the server's OpenAPI schema. See TYPES_CODEGEN.md. ### reactflow -[reactflow](https://github.com/xyflow/xyflow) powers the Workflow Editor. +[reactflow] powers the Workflow Editor. ### zod -[zod](https://github.com/colinhacks/zod) schemas are used to model data structures and provide runtime validation. +[zod] schemas are used to model data structures and provide runtime validation. ## Client Types Generation -We use [`openapi-typescript`](https://github.com/drwpow/openapi-typescript) to generate types from the app's OpenAPI schema. +We use [openapi-typescript] to generate types from the app's OpenAPI schema. The generated types are written to `invokeai/frontend/web/src/services/api/schema.d.ts`. This file is committed to the repo. @@ -107,11 +108,11 @@ Run with `yarn