From d6c9bf5b38785df1e7a29bea5132ea0f82424249 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 15 Aug 2023 11:46:37 -0400 Subject: [PATCH 01/51] added sdxl controlnet detection --- invokeai/backend/model_management/model_probe.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index 145c56c273..3045849065 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -481,9 +481,19 @@ class ControlNetFolderProbe(FolderProbeBase): with open(config_file, "r") as file: config = json.load(file) # no obvious way to distinguish between sd2-base and sd2-768 - return ( - BaseModelType.StableDiffusion1 if config["cross_attention_dim"] == 768 else BaseModelType.StableDiffusion2 + dimension = config["cross_attention_dim"] + base_model = ( + BaseModelType.StableDiffusion1 + if dimension == 768 + else BaseModelType.StableDiffusion2 + if dimension == 1024 + else BaseModelType.StableDiffusionXL + if dimension == 2048 + else None ) + if not base_model: + raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") + return base_model class LoRAFolderProbe(FolderProbeBase): From a4b029d03c952067287a9e5b0ceb2c2a58d18cc0 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 15 Aug 2023 18:21:31 -0400 Subject: [PATCH 02/51] write RAM usage and change after each generation --- invokeai/app/services/invocation_stats.py | 50 ++++++++++++++++------- 1 file changed, 36 insertions(+), 14 deletions(-) diff --git a/invokeai/app/services/invocation_stats.py b/invokeai/app/services/invocation_stats.py index 50320a6611..9d50375c09 100644 --- a/invokeai/app/services/invocation_stats.py +++ b/invokeai/app/services/invocation_stats.py @@ -29,6 +29,7 @@ The abstract base class for this class is InvocationStatsServiceBase. An impleme writes to the system log is stored in InvocationServices.performance_statistics. """ +import psutil import time from abc import ABC, abstractmethod from contextlib import AbstractContextManager @@ -83,13 +84,14 @@ class InvocationStatsServiceBase(ABC): pass @abstractmethod - def update_invocation_stats( - self, - graph_id: str, - invocation_type: str, - time_used: float, - vram_used: float, - ): + def update_invocation_stats(self, + graph_id: str, + invocation_type: str, + time_used: float, + vram_used: float, + ram_used: float, + ram_changed: float, + ): """ Add timing information on execution of a node. Usually used internally. @@ -97,6 +99,8 @@ class InvocationStatsServiceBase(ABC): :param invocation_type: String literal type of the node :param time_used: Time used by node's exection (sec) :param vram_used: Maximum VRAM used during exection (GB) + :param ram_used: Current RAM available (GB) + :param ram_changed: Change in RAM usage over course of the run (GB) """ pass @@ -140,18 +144,23 @@ class InvocationStatsService(InvocationStatsServiceBase): self.collector = collector self.graph_id = graph_id self.start_time = 0 + self.ram_info = None def __enter__(self): self.start_time = time.time() if torch.cuda.is_available(): torch.cuda.reset_peak_memory_stats() + self.ram_info = psutil.virtual_memory() + def __exit__(self, *args): self.collector.update_invocation_stats( - self.graph_id, - self.invocation.type, - time.time() - self.start_time, - torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0, + graph_id = self.graph_id, + invocation_type = self.invocation.type, + time_used = time.time() - self.start_time, + vram_used = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0, + ram_used = psutil.virtual_memory().used / 1e9, + ram_changed = (psutil.virtual_memory().used - self.ram_info.used) / 1e9, ) def collect_stats( @@ -179,13 +188,23 @@ class InvocationStatsService(InvocationStatsServiceBase): except KeyError: logger.warning(f"Attempted to clear statistics for unknown graph {graph_execution_id}") - def update_invocation_stats(self, graph_id: str, invocation_type: str, time_used: float, vram_used: float): + def update_invocation_stats(self, + graph_id: str, + invocation_type: str, + time_used: float, + vram_used: float, + ram_used: float, + ram_changed: float, + ): """ Add timing information on execution of a node. Usually used internally. :param graph_id: ID of the graph that is currently executing :param invocation_type: String literal type of the node - :param time_used: Floating point seconds used by node's exection + :param time_used: Time used by node's exection (sec) + :param vram_used: Maximum VRAM used during exection (GB) + :param ram_used: Current RAM available (GB) + :param ram_changed: Change in RAM usage over course of the run (GB) """ if not self._stats[graph_id].nodes.get(invocation_type): self._stats[graph_id].nodes[invocation_type] = NodeStats() @@ -193,6 +212,8 @@ class InvocationStatsService(InvocationStatsServiceBase): stats.calls += 1 stats.time_used += time_used stats.max_vram = max(stats.max_vram, vram_used) + stats.ram_used = ram_used + stats.ram_changed = ram_changed def log_stats(self): """ @@ -214,8 +235,9 @@ class InvocationStatsService(InvocationStatsServiceBase): total_time += stats.time_used logger.info(f"TOTAL GRAPH EXECUTION TIME: {total_time:7.3f}s") + logger.info("Current RAM used: " + "%4.2fG" % stats.ram_used + f" (delta={stats.ram_changed:4.2f}G)") if torch.cuda.is_available(): - logger.info("Current VRAM utilization " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) + logger.info("Current VRAM used: " + "%4.2fG" % (torch.cuda.memory_allocated() / 1e9)) completed.add(graph_id) From f49fc7fb5584c07c705a7792cfc2e7d75be79db5 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 14 Aug 2023 13:23:09 +1000 Subject: [PATCH 03/51] feat: node editor squashed rebase on main after backendd refactor --- invokeai/app/api_app.py | 7 +- invokeai/app/invocations/baseinvocation.py | 475 ++++++++- invokeai/app/invocations/collections.py | 104 +- invokeai/app/invocations/compel.py | 133 +-- .../controlnet_image_processors.py | 450 +++------ invokeai/app/invocations/cv.py | 30 +- invokeai/app/invocations/image.py | 352 +++---- invokeai/app/invocations/infill.py | 44 +- invokeai/app/invocations/latent.py | 186 ++-- invokeai/app/invocations/math.py | 102 +- invokeai/app/invocations/metadata.py | 122 ++- invokeai/app/invocations/model.py | 139 ++- invokeai/app/invocations/noise.py | 48 +- invokeai/app/invocations/onnx.py | 247 ++--- invokeai/app/invocations/param_easing.py | 117 +-- invokeai/app/invocations/params.py | 56 +- invokeai/app/invocations/prompt.py | 75 +- invokeai/app/invocations/sdxl.py | 73 +- invokeai/app/invocations/upscale.py | 14 +- invokeai/app/models/image.py | 14 +- invokeai/app/services/graph.py | 56 +- invokeai/app/services/processor.py | 5 +- invokeai/app/services/sqlite.py | 3 +- invokeai/frontend/web/package.json | 1 + invokeai/frontend/web/scripts/colors.js | 34 + invokeai/frontend/web/scripts/typegen.js | 66 +- .../web/src/app/components/GlobalHotkeys.ts | 33 +- .../web/src/app/components/InvokeAIUI.tsx | 6 +- .../frontend/web/src/app/logging/logger.ts | 3 +- .../middleware/devtools/actionsDenylist.ts | 2 +- .../listeners/imageDropped.ts | 36 +- .../listeners/imageUploaded.ts | 6 +- .../listeners/modelsLoaded.ts | 85 +- .../listeners/receivedOpenAPISchema.ts | 9 +- .../socketio/socketInvocationComplete.ts | 2 +- .../listeners/userInvokedNodes.ts | 2 +- .../frontend/web/src/app/types/invokeai.ts | 81 +- .../web/src/common/components/IAIDndImage.tsx | 25 +- .../src/common/components/IAIDraggable.tsx | 19 +- .../src/common/components/IAIDroppable.tsx | 10 +- .../common/components/IAIImageFallback.tsx | 4 +- .../web/src/common/components/IAISwitch.tsx | 45 +- .../src/common/hooks/useChakraThemeTokens.ts | 114 +++ .../frontend/web/src/common/util/serialize.ts | 8 +- .../components/ControlNetImagePreview.tsx | 2 +- .../src/features/controlNet/store/types.ts | 6 +- .../deleteImageModal/store/selectors.ts | 5 +- .../dnd/components/AppDndContext.tsx} | 49 +- .../dnd/components/DndContextTypesafe.tsx | 6 + .../dnd/components}/DragPreview.tsx | 33 +- .../src/features/dnd/hooks/typesafeHooks.ts | 15 + .../dnd/hooks/useScaledCenteredModifer.ts | 50 + .../dnd/types/index.ts} | 133 +-- .../web/src/features/dnd/util/isValidDrop.ts | 87 ++ .../Boards/BoardsList/GalleryBoard.tsx | 2 +- .../Boards/BoardsList/GenericBoard.tsx | 2 +- .../Boards/BoardsList/NoBoardBoard.tsx | 4 +- .../CurrentImage/CurrentImagePreview.tsx | 8 +- .../components/ImageGalleryContent.tsx | 2 + .../components/ImageGrid/GalleryImage.tsx | 10 +- .../components/ImageGrid/GalleryImageGrid.tsx | 2 +- .../ImageMetadataViewer/ImageMetadataJSON.tsx | 45 +- .../ImageMetadataViewer.tsx | 87 +- .../features/nodes/components/AddNodeMenu.tsx | 25 +- .../nodes/components/CustomConnectionLine.tsx | 61 ++ .../features/nodes/components/CustomEdges.tsx | 183 ++++ .../features/nodes/components/CustomNodes.tsx | 9 + .../features/nodes/components/FieldHandle.tsx | 64 -- .../nodes/components/FieldTypeLegend.tsx | 16 +- .../src/features/nodes/components/Flow.tsx | 94 +- .../components/IAINode/IAINodeHeader.tsx | 55 - .../components/IAINode/IAINodeInputs.tsx | 149 --- .../components/IAINode/IAINodeOutputs.tsx | 97 -- .../nodes/components/InputFieldComponent.tsx | 252 ----- .../Invocation/NodeCollapseButton.tsx | 57 ++ .../Invocation/NodeCollapsedHandles.tsx | 74 ++ .../components/Invocation/NodeFooter.tsx | 77 ++ .../components/Invocation/NodeNotesEdit.tsx | 113 +++ .../NodeResizer.tsx} | 7 +- .../components/Invocation/NodeSettings.tsx | 69 ++ .../Invocation/NodeStatusIndicator.tsx | 185 ++++ .../nodes/components/Invocation/NodeTitle.tsx | 123 +++ .../components/Invocation/NodeWrapper.tsx | 96 ++ .../nodes/components/InvocationComponent.tsx | 74 -- .../features/nodes/components/NodeEditor.tsx | 50 +- .../nodes/components/NodeEditorSettings.tsx | 139 +++ .../nodes/components/NodeGraphOverlay.tsx | 46 +- .../nodes/components/NodeOpacitySlider.tsx | 42 + .../features/nodes/components/NodeWrapper.tsx | 36 - .../nodes/components/ProgressImageNode.tsx | 73 -- .../nodes/components/ViewportControls.tsx | 39 +- .../BottomLeftPanel.tsx} | 7 +- .../{panels => editorPanels}/MinimapPanel.tsx | 5 +- .../TopCenterPanel.tsx | 6 +- .../{panels => editorPanels}/TopLeftPanel.tsx | 0 .../TopRightPanel.tsx | 8 +- .../fields/ArrayInputFieldComponent.tsx | 15 - .../fields/EnumInputFieldComponent.tsx | 37 - .../components/fields/FieldContextMenu.tsx | 47 + .../nodes/components/fields/FieldHandle.tsx | 122 +++ .../nodes/components/fields/FieldTitle.tsx | 161 +++ .../components/fields/FieldTooltipContent.tsx | 41 + .../nodes/components/fields/InputField.tsx | 153 +++ .../components/fields/InputFieldRenderer.tsx | 293 ++++++ .../fields/ItemInputFieldComponent.tsx | 15 - .../components/fields/LinearViewField.tsx | 88 ++ .../nodes/components/fields/OutputField.tsx | 114 +++ .../fields/StringInputFieldComponent.tsx | 36 - .../BooleanInputField.tsx} | 28 +- .../ClipInputField.tsx} | 0 .../fieldTypes/CollectionInputField.tsx | 17 + .../fieldTypes/CollectionItemInputField.tsx | 17 + .../ColorInputField.tsx} | 26 +- .../ConditioningInputField.tsx} | 0 .../ControlInputField.tsx} | 0 .../ControlNetModelInputField.tsx} | 11 +- .../fields/fieldTypes/EnumInputField.tsx | 45 + .../ImageCollectionInputField.tsx} | 15 +- .../ImageInputField.tsx} | 17 +- .../LatentsInputField.tsx} | 0 .../LoRAModelInputField.tsx} | 12 +- .../fields/fieldTypes/MainModelInputField.tsx | 144 +++ .../NumberInputField.tsx} | 28 +- .../RefinerModelInputField.tsx} | 22 +- .../SDXLMainModelInputField.tsx} | 51 +- .../fields/fieldTypes/StringInputField.tsx | 46 + .../UnetInputField.tsx} | 0 .../VaeInputField.tsx} | 0 .../VaeModelInputField.tsx} | 7 +- .../fields/{ => fieldTypes}/types.ts | 7 +- .../components/nodes/CurrentImageNode.tsx | 93 ++ .../nodes/components/nodes/InvocationNode.tsx | 127 +++ .../nodes/components/nodes/NotesNode.tsx | 73 ++ .../nodes/components/panel/InspectorPanel.tsx | 101 ++ .../components/panel/NodeDataInspector.tsx | 36 + .../components/panel/NodeEditorPanelGroup.tsx | 49 + .../components/panel/ScrollableContent.tsx | 45 + .../nodes/components/panel/WorkflowPanel.tsx | 52 + .../components/panel/workflow/GeneralTab.tsx | 142 +++ .../components/panel/workflow/LinearTab.tsx | 114 +++ .../components/panel/workflow/NotesTab.tsx | 51 + .../components/panel/workflow/WorkflowTab.tsx | 43 + .../nodes/components/search/NodeSearch.tsx | 21 +- .../nodes/components/ui/LoadGraphButton.tsx | 161 --- .../nodes/components/ui/NodeInvokeButton.tsx | 16 +- .../nodes/components/ui/SaveGraphButton.tsx | 48 - ...BuildInvocation.ts => useBuildNodeData.ts} | 68 +- .../nodes/hooks/useConnectionState.ts | 92 ++ .../nodes/hooks/useIsValidConnection.ts | 166 +-- .../nodes/store/nodesPersistDenylist.ts | 6 +- .../src/features/nodes/store/nodesSlice.ts | 614 ++++++++++- .../web/src/features/nodes/store/selectors.ts | 92 ++ .../web/src/features/nodes/store/types.ts | 31 +- .../util/makeIsConnectionValidSelector.ts | 92 ++ .../nodes/store/util/makeTemplateSelector.ts | 23 +- .../web/src/features/nodes/types/constants.ts | 210 ++-- .../web/src/features/nodes/types/types.ts | 434 +++++--- .../src/features/nodes/util/buildWorkflow.ts | 42 + .../nodes/util/fieldTemplateBuilders.ts | 415 +++++--- .../features/nodes/util/fieldValueBuilders.ts | 130 +-- .../addControlNetToLinearGraph.ts | 4 +- .../util/graphBuilders/addLoRAsToGraph.ts | 2 +- .../graphBuilders/addNSFWCheckerToGraph.ts | 2 +- .../util/graphBuilders/addSDXLLoRAstoGraph.ts | 3 + .../graphBuilders/addSDXLRefinerToGraph.ts | 4 +- .../graphBuilders/buildCanvasOutpaintGraph.ts | 4 +- .../buildCanvasSDXLOutpaintGraph.ts | 4 +- .../util/graphBuilders/buildNodesGraph.ts | 17 +- .../src/features/nodes/util/parseSchema.ts | 102 +- .../Parameters/ImageToImage/InitialImage.tsx | 8 +- .../SettingsModal/SettingsModal.tsx | 31 +- .../src/features/system/store/systemSlice.ts | 8 +- .../src/features/ui/components/InvokeTabs.tsx | 25 +- .../ui/components/tabs/Nodes/NodesTab.tsx | 9 +- .../ui/components/tabs/ResizeHandle.tsx | 55 +- .../UnifiedCanvas/UnifiedCanvasContent.tsx | 13 +- .../features/ui/hooks/useMinimumPanelSize.ts | 9 +- .../web/src/features/ui/store/hotkeysSlice.ts | 13 +- .../web/src/services/api/constants.ts | 3 + .../frontend/web/src/services/api/schema.d.ts | 954 ++++++++++-------- .../web/src/services/api/thunks/schema.ts | 6 +- .../frontend/web/src/services/api/types.ts | 298 ++---- .../web/src/theme/components/checkbox.ts | 9 +- .../web/src/theme/components/formLabel.ts | 3 + .../frontend/web/src/theme/components/tabs.ts | 47 +- .../web/src/theme/components/textarea.ts | 1 + .../web/src/theme/custom/reactflow.ts | 21 + invokeai/frontend/web/src/theme/theme.ts | 22 +- 188 files changed, 8541 insertions(+), 4660 deletions(-) create mode 100644 invokeai/frontend/web/scripts/colors.js rename invokeai/frontend/web/src/{app/components/ImageDnd/ImageDndContext.tsx => features/dnd/components/AppDndContext.tsx} (70%) create mode 100644 invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx rename invokeai/frontend/web/src/{app/components/ImageDnd => features/dnd/components}/DragPreview.tsx (69%) create mode 100644 invokeai/frontend/web/src/features/dnd/hooks/typesafeHooks.ts create mode 100644 invokeai/frontend/web/src/features/dnd/hooks/useScaledCenteredModifer.ts rename invokeai/frontend/web/src/{app/components/ImageDnd/typesafeDnd.tsx => features/dnd/types/index.ts} (51%) create mode 100644 invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts create mode 100644 invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx rename invokeai/frontend/web/src/features/nodes/components/{IAINode/IAINodeResizer.tsx => Invocation/NodeResizer.tsx} (73%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx rename invokeai/frontend/web/src/features/nodes/components/{panels/BottomLeftPanel.tsx.tsx => editorPanels/BottomLeftPanel.tsx} (55%) rename invokeai/frontend/web/src/features/nodes/components/{panels => editorPanels}/MinimapPanel.tsx (91%) rename invokeai/frontend/web/src/features/nodes/components/{panels => editorPanels}/TopCenterPanel.tsx (79%) rename invokeai/frontend/web/src/features/nodes/components/{panels => editorPanels}/TopLeftPanel.tsx (100%) rename invokeai/frontend/web/src/features/nodes/components/{panels => editorPanels}/TopRightPanel.tsx (55%) delete mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx rename invokeai/frontend/web/src/features/nodes/components/fields/{BooleanInputFieldComponent.tsx => fieldTypes/BooleanInputField.tsx} (53%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ClipInputFieldComponent.tsx => fieldTypes/ClipInputField.tsx} (100%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionInputField.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/CollectionItemInputField.tsx rename invokeai/frontend/web/src/features/nodes/components/fields/{ColorInputFieldComponent.tsx => fieldTypes/ColorInputField.tsx} (57%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ConditioningInputFieldComponent.tsx => fieldTypes/ConditioningInputField.tsx} (100%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ControlInputFieldComponent.tsx => fieldTypes/ControlInputField.tsx} (100%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ControlNetModelInputFieldComponent.tsx => fieldTypes/ControlNetModelInputField.tsx} (91%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/EnumInputField.tsx rename invokeai/frontend/web/src/features/nodes/components/fields/{ImageCollectionInputFieldComponent.tsx => fieldTypes/ImageCollectionInputField.tsx} (86%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ImageInputFieldComponent.tsx => fieldTypes/ImageInputField.tsx} (88%) rename invokeai/frontend/web/src/features/nodes/components/fields/{LatentsInputFieldComponent.tsx => fieldTypes/LatentsInputField.tsx} (100%) rename invokeai/frontend/web/src/features/nodes/components/fields/{LoRAModelInputFieldComponent.tsx => fieldTypes/LoRAModelInputField.tsx} (92%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/MainModelInputField.tsx rename invokeai/frontend/web/src/features/nodes/components/fields/{NumberInputFieldComponent.tsx => fieldTypes/NumberInputField.tsx} (67%) rename invokeai/frontend/web/src/features/nodes/components/fields/{RefinerModelInputFieldComponent.tsx => fieldTypes/RefinerModelInputField.tsx} (89%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ModelInputFieldComponent.tsx => fieldTypes/SDXLMainModelInputField.tsx} (78%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/fields/fieldTypes/StringInputField.tsx rename invokeai/frontend/web/src/features/nodes/components/fields/{UnetInputFieldComponent.tsx => fieldTypes/UnetInputField.tsx} (100%) rename invokeai/frontend/web/src/features/nodes/components/fields/{VaeInputFieldComponent.tsx => fieldTypes/VaeInputField.tsx} (100%) rename invokeai/frontend/web/src/features/nodes/components/fields/{VaeModelInputFieldComponent.tsx => fieldTypes/VaeModelInputField.tsx} (93%) rename invokeai/frontend/web/src/features/nodes/components/fields/{ => fieldTypes}/types.ts (60%) create mode 100644 invokeai/frontend/web/src/features/nodes/components/nodes/CurrentImageNode.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/nodes/InvocationNode.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/nodes/NotesNode.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/InspectorPanel.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/NodeDataInspector.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/NodeEditorPanelGroup.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/ScrollableContent.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/WorkflowPanel.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/workflow/GeneralTab.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/workflow/LinearTab.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/workflow/NotesTab.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/components/panel/workflow/WorkflowTab.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/ui/LoadGraphButton.tsx delete mode 100644 invokeai/frontend/web/src/features/nodes/components/ui/SaveGraphButton.tsx rename invokeai/frontend/web/src/features/nodes/hooks/{useBuildInvocation.ts => useBuildNodeData.ts} (69%) create mode 100644 invokeai/frontend/web/src/features/nodes/hooks/useConnectionState.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/selectors.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/buildWorkflow.ts create mode 100644 invokeai/frontend/web/src/theme/custom/reactflow.ts diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index 6b875d37ce..20b2781ef0 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -38,7 +38,7 @@ import mimetypes from .api.dependencies import ApiDependencies from .api.routers import sessions, models, images, boards, board_images, app_info from .api.sockets import SocketIO -from .invocations.baseinvocation import BaseInvocation +from .invocations.baseinvocation import BaseInvocation, _InputField, _OutputField, UIConfigBase import torch @@ -134,6 +134,11 @@ def custom_openapi(): # This could break in some cases, figure out a better way to do it output_type_titles[schema_key] = output_schema["title"] + # Add Node Editor UI helper schemas + ui_config_schemas = schema([UIConfigBase, _InputField, _OutputField], ref_prefix="#/components/schemas/") + for schema_key, output_schema in ui_config_schemas["definitions"].items(): + openapi_schema["components"]["schemas"][schema_key] = output_schema + # Add a reference to the output type to additionalProperties of the invoker schema for invoker in all_invocations: invoker_name = invoker.__name__ diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 758ab2e787..65aeef75d8 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -3,15 +3,353 @@ from __future__ import annotations from abc import ABC, abstractmethod +from enum import Enum from inspect import signature -from typing import TYPE_CHECKING, Dict, List, Literal, TypedDict, get_args, get_type_hints +from typing import ( + TYPE_CHECKING, + AbstractSet, + Any, + Callable, + ClassVar, + Mapping, + Optional, + Type, + TypeVar, + Union, + get_args, + get_type_hints, +) -from pydantic import BaseConfig, BaseModel, Field +from pydantic import BaseModel, Field +from pydantic.fields import Undefined +from pydantic.typing import NoArgAnyCallable if TYPE_CHECKING: from ..services.invocation_services import InvocationServices +class FieldDescriptions: + denoising_start = "When to start denoising, expressed a percentage of total steps" + denoising_end = "When to stop denoising, expressed a percentage of total steps" + cfg_scale = "Classifier-Free Guidance scale" + scheduler = "Scheduler to use during inference" + positive_cond = "Positive conditioning tensor" + negative_cond = "Negative conditioning tensor" + noise = "Noise tensor" + clip = "CLIP (tokenizer, text encoder, LoRAs) and skipped layer count" + unet = "UNet (scheduler, LoRAs)" + vae = "VAE" + cond = "Conditioning tensor" + controlnet_model = "ControlNet model to load" + vae_model = "VAE model to load" + lora_model = "LoRA model to load" + main_model = "Main model (UNet, VAE, CLIP) to load" + sdxl_main_model = "SDXL Main model (UNet, VAE, CLIP1, CLIP2) to load" + sdxl_refiner_model = "SDXL Refiner Main Modde (UNet, VAE, CLIP2) to load" + onnx_main_model = "ONNX Main model (UNet, VAE, CLIP) to load" + lora_weight = "The weight at which the LoRA is applied to each model" + compel_prompt = "Prompt to be parsed by Compel to create a conditioning tensor" + raw_prompt = "Raw prompt text (no parsing)" + sdxl_aesthetic = "The aesthetic score to apply to the conditioning tensor" + skipped_layers = "Number of layers to skip in text encoder" + seed = "Seed for random number generation" + steps = "Number of steps to run" + width = "Width of output (px)" + height = "Height of output (px)" + control = "ControlNet(s) to apply" + denoised_latents = "Denoised latents tensor" + latents = "Latents tensor" + strength = "Strength of denoising (proportional to steps)" + core_metadata = "Optional core metadata to be written to image" + interp_mode = "Interpolation mode" + torch_antialias = "Whether or not to apply antialiasing (bilinear or bicubic only)" + fp32 = "Whether or not to use full float32 precision" + precision = "Precision to use" + tiled = "Processing using overlapping tiles (reduce memory consumption)" + detect_res = "Pixel resolution for detection" + image_res = "Pixel resolution for output image" + safe_mode = "Whether or not to use safe mode" + scribble_mode = "Whether or not to use scribble mode" + scale_factor = "The factor by which to scale" + num_1 = "The first number" + num_2 = "The second number" + mask = "The mask to use for the operation" + + +class Input(str, Enum): + """ + The type of input a field accepts. + - `Input.Direct`: The field must have its value provided directly, when the invocation and field \ + are instantiated. + - `Input.Connection`: The field must have its value provided by a connection. + - `Input.Any`: The field may have its value provided either directly or by a connection. + """ + + Connection = "connection" + Direct = "direct" + Any = "any" + + +class UITypeHint(str, Enum): + """ + Type hints for the UI. + If a field should be provided a data type that does not exactly match the python type of the field, \ + use this to provide the type that should be used instead. See the node development docs for detail \ + on adding a new field type, which involves client-side changes. + """ + + Integer = "integer" + Float = "float" + Boolean = "boolean" + String = "string" + Enum = "enum" + Array = "array" + ImageField = "ImageField" + LatentsField = "LatentsField" + ConditioningField = "ConditioningField" + ControlField = "ControlField" + MainModelField = "MainModelField" + SDXLMainModelField = "SDXLMainModelField" + SDXLRefinerModelField = "SDXLRefinerModelField" + ONNXModelField = "ONNXModelField" + VaeModelField = "VaeModelField" + LoRAModelField = "LoRAModelField" + ControlNetModelField = "ControlNetModelField" + UNetField = "UNetField" + VaeField = "VaeField" + ClipField = "ClipField" + ColorField = "ColorField" + ImageCollection = "ImageCollection" + IntegerCollection = "IntegerCollection" + FloatCollection = "FloatCollection" + StringCollection = "StringCollection" + BooleanCollection = "BooleanCollection" + Collection = "Collection" + CollectionItem = "CollectionItem" + Seed = "Seed" + FilePath = "FilePath" + + +class UIComponent(str, Enum): + """ + The type of UI component to use for a field, used to override the default components, which are \ + inferred from the field type. + """ + + None_ = "none" + Textarea = "textarea" + Slider = "slider" + + +class _InputField(BaseModel): + """ + *DO NOT USE* + This helper class is used to tell the client about our custom field attributes via OpenAPI + schema generation, and Typescript type generation from that schema. It serves no functional + purpose in the backend. + """ + + input: Input + ui_hidden: bool + ui_type_hint: Optional[UITypeHint] + ui_component: Optional[UIComponent] + + +class _OutputField(BaseModel): + """ + *DO NOT USE* + This helper class is used to tell the client about our custom field attributes via OpenAPI + schema generation, and Typescript type generation from that schema. It serves no functional + purpose in the backend. + """ + + ui_hidden: bool + ui_type_hint: Optional[UITypeHint] + + +def InputField( + *args: Any, + default: Any = Undefined, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None, + include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + allow_inf_nan: Optional[bool] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + input: Input = Input.Any, + ui_type_hint: Optional[UITypeHint] = None, + ui_component: Optional[UIComponent] = None, + ui_hidden: bool = False, + **kwargs: Any, +) -> Any: + """ + Creates an input field for an invocation. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param Input input: [Input.Any] The kind of input this field requires. \ + `Input.Direct` means a value must be provided on instantiation. \ + `Input.Connection` means the value must be provided by a connection. \ + `Input.Any` means either will do. + + :param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + :param UIComponent ui_component: [None] Optionally specifies a specific component to use in the UI. \ + The UI will always render a suitable component, but sometimes you want something different than the default. \ + For example, a `string` field will default to a single-line input, but you may want a multi-line textarea instead. \ + For this case, you could provide `UIComponent.Textarea`. + + : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. + """ + return Field( + *args, + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_items=min_items, + max_items=max_items, + unique_items=unique_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + discriminator=discriminator, + repr=repr, + input=input, + ui_type_hint=ui_type_hint, + ui_component=ui_component, + ui_hidden=ui_hidden, + **kwargs, + ) + + +def OutputField( + *args: Any, + default: Any = Undefined, + default_factory: Optional[NoArgAnyCallable] = None, + alias: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + exclude: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None, + include: Optional[Union[AbstractSet[Union[int, str]], Mapping[Union[int, str], Any], Any]] = None, + const: Optional[bool] = None, + gt: Optional[float] = None, + ge: Optional[float] = None, + lt: Optional[float] = None, + le: Optional[float] = None, + multiple_of: Optional[float] = None, + allow_inf_nan: Optional[bool] = None, + max_digits: Optional[int] = None, + decimal_places: Optional[int] = None, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + unique_items: Optional[bool] = None, + min_length: Optional[int] = None, + max_length: Optional[int] = None, + allow_mutation: bool = True, + regex: Optional[str] = None, + discriminator: Optional[str] = None, + repr: bool = True, + ui_type_hint: Optional[UITypeHint] = None, + ui_hidden: bool = False, + **kwargs: Any, +) -> Any: + """ + Creates an output field for an invocation output. + + This is a wrapper for Pydantic's [Field](https://docs.pydantic.dev/1.10/usage/schema/#field-customization) \ + that adds a few extra parameters to support graph execution and the node editor UI. + + :param UITypeHint ui_type_hint: [None] Optionally provides an extra type hint for the UI. \ + In some situations, the field's type is not enough to infer the correct UI type. \ + For example, model selection fields should render a dropdown UI component to select a model. \ + Internally, there is no difference between SD-1, SD-2 and SDXL model fields, they all use \ + `MainModelField`. So to ensure the base-model-specific UI is rendered, you can use \ + `UITypeHint.SDXLMainModelField` to indicate that the field is an SDXL main model field. + + : param bool ui_hidden: [False] Specifies whether or not this field should be hidden in the UI. \ + """ + return Field( + *args, + default=default, + default_factory=default_factory, + alias=alias, + title=title, + description=description, + exclude=exclude, + include=include, + const=const, + gt=gt, + ge=ge, + lt=lt, + le=le, + multiple_of=multiple_of, + allow_inf_nan=allow_inf_nan, + max_digits=max_digits, + decimal_places=decimal_places, + min_items=min_items, + max_items=max_items, + unique_items=unique_items, + min_length=min_length, + max_length=max_length, + allow_mutation=allow_mutation, + regex=regex, + discriminator=discriminator, + repr=repr, + ui_type_hint=ui_type_hint, + ui_hidden=ui_hidden, + **kwargs, + ) + + +class UIConfigBase(BaseModel): + """ + Provides additional node configuration to the UI. + This is used internally by the @tags and @title decorator logic. You probably want to use those + decorators, though you may add this class to a node definition to specify the title and tags. + """ + + tags: Optional[list[str]] = Field(default_factory=None, description="The tags to display in the UI") + title: Optional[str] = Field(default=None, description="The display name of the node") + + class InvocationContext: services: InvocationServices graph_execution_state_id: str @@ -39,6 +377,20 @@ class BaseInvocationOutput(BaseModel): return tuple(subclasses) +class RequiredConnectionException(Exception): + """Raised when an field which requires a connection did not receive a value.""" + + def __init__(self, node_id: str, field_name: str): + super().__init__(f"Node {node_id} missing connections for field {field_name}") + + +class MissingInputException(Exception): + """Raised when an field which requires some input, but did not receive a value.""" + + def __init__(self, node_id: str, field_name: str): + super().__init__(f"Node {node_id} missing value or connection for field {field_name}") + + class BaseInvocation(ABC, BaseModel): """A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers. @@ -76,70 +428,81 @@ class BaseInvocation(ABC, BaseModel): def get_output_type(cls): return signature(cls.invoke).return_annotation + class Config: + @staticmethod + def schema_extra(schema: dict[str, Any], model_class: Type[BaseModel]) -> None: + uiconfig = getattr(model_class, "UIConfig", None) + if uiconfig and hasattr(uiconfig, "title"): + schema["title"] = uiconfig.title + if uiconfig and hasattr(uiconfig, "tags"): + schema["tags"] = uiconfig.tags + @abstractmethod def invoke(self, context: InvocationContext) -> BaseInvocationOutput: """Invoke with provided context and return outputs.""" pass - # fmt: off - id: str = Field(description="The id of this node. Must be unique among all nodes.") - is_intermediate: bool = Field(default=False, description="Whether or not this node is an intermediate node.") - # fmt: on + def __init__(self, **data): + # nodes may have required fields, that can accept input from connections + # on instantiation of the model, we need to exclude these from validation + restore = dict() + try: + field_names = list(self.__fields__.keys()) + for field_name in field_names: + # if the field is required and may get its value from a connection, exclude it from validation + field = self.__fields__[field_name] + _input = field.field_info.extra.get("input", None) + if _input in [Input.Connection, Input.Any] and field.required: + if field_name not in data: + restore[field_name] = self.__fields__.pop(field_name) + # instantiate the node, which will validate the data + super().__init__(**data) + finally: + # restore the removed fields + for field_name, field in restore.items(): + self.__fields__[field_name] = field + + def invoke_internal(self, context: InvocationContext) -> BaseInvocationOutput: + for field_name, field in self.__fields__.items(): + _input = field.field_info.extra.get("input", None) + if field.required and not hasattr(self, field_name): + if _input == Input.Connection: + raise RequiredConnectionException(self.__fields__["type"].default, field_name) + elif _input == Input.Any: + raise MissingInputException(self.__fields__["type"].default, field_name) + return self.invoke(context) + + id: str = InputField(description="The id of this node. Must be unique among all nodes.") + is_intermediate: bool = InputField( + default=False, description="Whether or not this node is an intermediate node.", input=Input.Direct + ) + UIConfig: ClassVar[Type[UIConfigBase]] -# TODO: figure out a better way to provide these hints -# TODO: when we can upgrade to python 3.11, we can use the`NotRequired` type instead of `total=False` -class UIConfig(TypedDict, total=False): - type_hints: Dict[ - str, - Literal[ - "integer", - "float", - "boolean", - "string", - "enum", - "image", - "latents", - "model", - "control", - "image_collection", - "vae_model", - "lora_model", - ], - ] - tags: List[str] - title: str +T = TypeVar("T", bound=BaseInvocation) -class CustomisedSchemaExtra(TypedDict): - ui: UIConfig +def title(title: str) -> Callable[[Type[T]], Type[T]]: + """Adds a title to the invocation. Use this to override the default title generation, which is based on the class name.""" + + def wrapper(cls: Type[T]) -> Type[T]: + uiconf_name = cls.__qualname__ + ".UIConfig" + if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: + cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict()) + cls.UIConfig.title = title + return cls + + return wrapper -class InvocationConfig(BaseConfig): - """Customizes pydantic's BaseModel.Config class for use by Invocations. +def tags(*tags: str) -> Callable[[Type[T]], Type[T]]: + """Adds tags to the invocation. Use this to improve the streamline finding the invocation in the UI.""" - Provide `schema_extra` a `ui` dict to add hints for generated UIs. + def wrapper(cls: Type[T]) -> Type[T]: + uiconf_name = cls.__qualname__ + ".UIConfig" + if not hasattr(cls, "UIConfig") or cls.UIConfig.__qualname__ != uiconf_name: + cls.UIConfig = type(uiconf_name, (UIConfigBase,), dict()) + cls.UIConfig.tags = list(tags) + return cls - `tags` - - A list of strings, used to categorise invocations. - - `type_hints` - - A dict of field types which override the types in the invocation definition. - - Each key should be the name of one of the invocation's fields. - - Each value should be one of the valid types: - - `integer`, `float`, `boolean`, `string`, `enum`, `image`, `latents`, `model` - - ```python - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["stable-diffusion", "image"], - "type_hints": { - "initial_image": "image", - }, - }, - } - ``` - """ - - schema_extra: CustomisedSchemaExtra + return wrapper diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index 01c003da96..0dd3b757dc 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -3,58 +3,78 @@ from typing import Literal import numpy as np -from pydantic import Field, validator +from pydantic import validator from invokeai.app.models.image import ImageField from invokeai.app.util.misc import SEED_MAX, get_random_seed -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext, UIConfig +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InputField, + InvocationContext, + OutputField, + UITypeHint, + tags, + title, +) class IntCollectionOutput(BaseInvocationOutput): """A collection of integers""" - type: Literal["int_collection"] = "int_collection" + type: Literal["int_collection_output"] = "int_collection_output" # Outputs - collection: list[int] = Field(default=[], description="The int collection") + collection: list[int] = OutputField( + default=[], description="The int collection", ui_type_hint=UITypeHint.IntegerCollection + ) class FloatCollectionOutput(BaseInvocationOutput): """A collection of floats""" - type: Literal["float_collection"] = "float_collection" + type: Literal["float_collection_output"] = "float_collection_output" # Outputs - collection: list[float] = Field(default=[], description="The float collection") + collection: list[float] = OutputField( + default=[], description="The float collection", ui_type_hint=UITypeHint.FloatCollection + ) + + +class StringCollectionOutput(BaseInvocationOutput): + """A collection of strings""" + + type: Literal["string_collection_output"] = "string_collection_output" + + # Outputs + collection: list[str] = OutputField( + default=[], description="The output strings", ui_type_hint=UITypeHint.StringCollection + ) class ImageCollectionOutput(BaseInvocationOutput): """A collection of images""" - type: Literal["image_collection"] = "image_collection" + type: Literal["image_collection_output"] = "image_collection_output" # Outputs - collection: list[ImageField] = Field(default=[], description="The output images") - - class Config: - schema_extra = {"required": ["type", "collection"]} + collection: list[ImageField] = OutputField( + default=[], description="The output images", ui_type_hint=UITypeHint.ImageCollection + ) +@title("Integer Range") +@tags("collection", "integer", "range") class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" type: Literal["range"] = "range" # Inputs - start: int = Field(default=0, description="The start of the range") - stop: int = Field(default=10, description="The stop of the range") - step: int = Field(default=1, description="The step of the range") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Range", "tags": ["range", "integer", "collection"]}, - } + start: int = InputField(default=0, description="The start of the range") + stop: int = InputField(default=10, description="The stop of the range") + step: int = InputField(default=1, description="The step of the range") @validator("stop") def stop_gt_start(cls, v, values): @@ -66,72 +86,56 @@ class RangeInvocation(BaseInvocation): return IntCollectionOutput(collection=list(range(self.start, self.stop, self.step))) +@title("Integer Range of Size") +@tags("range", "integer", "size", "collection") class RangeOfSizeInvocation(BaseInvocation): """Creates a range from start to start + size with step""" type: Literal["range_of_size"] = "range_of_size" # Inputs - start: int = Field(default=0, description="The start of the range") - size: int = Field(default=1, description="The number of values") - step: int = Field(default=1, description="The step of the range") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Sized Range", "tags": ["range", "integer", "size", "collection"]}, - } + start: int = InputField(default=0, description="The start of the range") + size: int = InputField(default=1, description="The number of values") + step: int = InputField(default=1, description="The step of the range") def invoke(self, context: InvocationContext) -> IntCollectionOutput: return IntCollectionOutput(collection=list(range(self.start, self.start + self.size, self.step))) +@title("Random Range") +@tags("range", "integer", "random", "collection") class RandomRangeInvocation(BaseInvocation): """Creates a collection of random numbers""" type: Literal["random_range"] = "random_range" # Inputs - low: int = Field(default=0, description="The inclusive low value") - high: int = Field(default=np.iinfo(np.int32).max, description="The exclusive high value") - size: int = Field(default=1, description="The number of values to generate") - seed: int = Field( + low: int = InputField(default=0, description="The inclusive low value") + high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") + size: int = InputField(default=1, description="The number of values to generate") + seed: int = InputField( ge=0, le=SEED_MAX, description="The seed for the RNG (omit for random)", default_factory=get_random_seed, ) - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Random Range", "tags": ["range", "integer", "random", "collection"]}, - } - def invoke(self, context: InvocationContext) -> IntCollectionOutput: rng = np.random.default_rng(self.seed) return IntCollectionOutput(collection=list(rng.integers(low=self.low, high=self.high, size=self.size))) +@title("Image Collection") +@tags("image", "collection") class ImageCollectionInvocation(BaseInvocation): """Load a collection of images and provide it as output.""" - # fmt: off type: Literal["image_collection"] = "image_collection" # Inputs - images: list[ImageField] = Field( - default=[], description="The image collection to load" + images: list[ImageField] = InputField( + default=[], description="The image collection to load", ui_type_hint=UITypeHint.ImageCollection ) - # fmt: on def invoke(self, context: InvocationContext) -> ImageCollectionOutput: return ImageCollectionOutput(collection=self.images) - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "type_hints": { - "title": "Image Collection", - "images": "image_collection", - } - }, - } diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 86565366d9..0f7c61a6dd 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -1,29 +1,39 @@ -from typing import Literal, Optional, Union, List, Annotated -from pydantic import BaseModel, Field import re - -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig -from .model import ClipField - -from ...backend.util.devices import torch_dtype -from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent -from ...backend.model_management import BaseModelType, ModelType, SubModelType, ModelPatcher +from dataclasses import dataclass +from typing import List, Literal, Union import torch from compel import Compel, ReturnedEmbeddingsType from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment -from ...backend.util.devices import torch_dtype -from ...backend.model_management import ModelType -from ...backend.model_management.models import ModelNotFoundException +from pydantic import BaseModel, Field + +from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import ( + BasicConditioningInfo, + SDXLConditioningInfo, +) + +from ...backend.model_management import ModelPatcher, ModelType from ...backend.model_management.lora import ModelPatcher -from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from ...backend.model_management.models import ModelNotFoundException +from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent +from ...backend.util.devices import torch_dtype +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UIComponent, + tags, + title, +) from .model import ClipField -from dataclasses import dataclass class ConditioningField(BaseModel): - conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data") + conditioning_name: str = Field(description="The name of conditioning data") class Config: schema_extra = {"required": ["conditioning_name"]} @@ -47,23 +57,27 @@ class CompelOutput(BaseInvocationOutput): # fmt: off type: Literal["compel_output"] = "compel_output" - conditioning: ConditioningField = Field(default=None, description="Conditioning") + conditioning: ConditioningField = OutputField(description=FieldDescriptions.cond) # fmt: on +@title("Compel Prompt") +@tags("prompt", "compel") class CompelInvocation(BaseInvocation): """Parse prompt using compel package to conditioning.""" type: Literal["compel"] = "compel" - prompt: str = Field(default="", description="Prompt") - clip: ClipField = Field(None, description="Clip to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, - } + prompt: str = InputField( + default="", + description=FieldDescriptions.compel_prompt, + ui_component=UIComponent.Textarea, + ) + clip: ClipField = InputField( + title="CLIP", + description=FieldDescriptions.clip, + input=Input.Connection, + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: @@ -270,27 +284,23 @@ class SDXLPromptInvocationBase: return c, c_pooled, ec +@title("SDXL Compel Prompt") +@tags("sdxl", "compel", "prompt") class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_compel_prompt"] = "sdxl_compel_prompt" - prompt: str = Field(default="", description="Prompt") - style: str = Field(default="", description="Style prompt") - original_width: int = Field(1024, description="") - original_height: int = Field(1024, description="") - crop_top: int = Field(0, description="") - crop_left: int = Field(0, description="") - target_width: int = Field(1024, description="") - target_height: int = Field(1024, description="") - clip: ClipField = Field(None, description="Clip to use") - clip2: ClipField = Field(None, description="Clip2 to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "SDXL Prompt (Compel)", "tags": ["prompt", "compel"], "type_hints": {"model": "model"}}, - } + prompt: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) + style: str = InputField(default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea) + original_width: int = InputField(default=1024, description="") + original_height: int = InputField(default=1024, description="") + crop_top: int = InputField(default=0, description="") + crop_left: int = InputField(default=0, description="") + target_width: int = InputField(default=1024, description="") + target_height: int = InputField(default=1024, description="") + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) + clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: @@ -333,28 +343,22 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): ) +@title("SDXL Refiner Compel Prompt") +@tags("sdxl", "compel", "prompt") class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): """Parse prompt using compel package to conditioning.""" type: Literal["sdxl_refiner_compel_prompt"] = "sdxl_refiner_compel_prompt" - style: str = Field(default="", description="Style prompt") # TODO: ? - original_width: int = Field(1024, description="") - original_height: int = Field(1024, description="") - crop_top: int = Field(0, description="") - crop_left: int = Field(0, description="") - aesthetic_score: float = Field(6.0, description="") - clip2: ClipField = Field(None, description="Clip to use") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Refiner Prompt (Compel)", - "tags": ["prompt", "compel"], - "type_hints": {"model": "model"}, - }, - } + style: str = InputField( + default="", description=FieldDescriptions.compel_prompt, ui_component=UIComponent.Textarea + ) # TODO: ? + original_width: int = InputField(default=1024, description="") + original_height: int = InputField(default=1024, description="") + crop_top: int = InputField(default=0, description="") + crop_left: int = InputField(default=0, description="") + aesthetic_score: float = InputField(default=6.0, description=FieldDescriptions.sdxl_aesthetic) + clip2: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) @torch.no_grad() def invoke(self, context: InvocationContext) -> CompelOutput: @@ -391,21 +395,18 @@ class ClipSkipInvocationOutput(BaseInvocationOutput): """Clip skip node output""" type: Literal["clip_skip_output"] = "clip_skip_output" - clip: ClipField = Field(None, description="Clip with skipped layers") + clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") +@title("CLIP Skip") +@tags("clipskip", "clip", "skip") class ClipSkipInvocation(BaseInvocation): """Skip layers in clip text_encoder model.""" type: Literal["clip_skip"] = "clip_skip" - clip: ClipField = Field(None, description="Clip to use") - skipped_layers: int = Field(0, description="Number of layers to skip in text_encoder") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "CLIP Skip", "tags": ["clip", "skip"]}, - } + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection, title="CLIP") + skipped_layers: int = InputField(default=0, description=FieldDescriptions.skipped_layers) def invoke(self, context: InvocationContext) -> ClipSkipInvocationOutput: self.clip.skipped_layers += self.skipped_layers diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index d2b2d44526..de8ad00026 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -28,77 +28,27 @@ from pydantic import BaseModel, Field, validator from ...backend.model_management import BaseModelType, ModelType from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from ..models.image import ImageOutput, PILInvocationConfig +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UITypeHint, + tags, + title, +) +from ..models.image import ImageOutput -CONTROLNET_DEFAULT_MODELS = [ - ########################################### - # lllyasviel sd v1.5, ControlNet v1.0 models - ############################################## - "lllyasviel/sd-controlnet-canny", - "lllyasviel/sd-controlnet-depth", - "lllyasviel/sd-controlnet-hed", - "lllyasviel/sd-controlnet-seg", - "lllyasviel/sd-controlnet-openpose", - "lllyasviel/sd-controlnet-scribble", - "lllyasviel/sd-controlnet-normal", - "lllyasviel/sd-controlnet-mlsd", - ############################################# - # lllyasviel sd v1.5, ControlNet v1.1 models - ############################################# - "lllyasviel/control_v11p_sd15_canny", - "lllyasviel/control_v11p_sd15_openpose", - "lllyasviel/control_v11p_sd15_seg", - # "lllyasviel/control_v11p_sd15_depth", # broken - "lllyasviel/control_v11f1p_sd15_depth", - "lllyasviel/control_v11p_sd15_normalbae", - "lllyasviel/control_v11p_sd15_scribble", - "lllyasviel/control_v11p_sd15_mlsd", - "lllyasviel/control_v11p_sd15_softedge", - "lllyasviel/control_v11p_sd15s2_lineart_anime", - "lllyasviel/control_v11p_sd15_lineart", - "lllyasviel/control_v11p_sd15_inpaint", - # "lllyasviel/control_v11u_sd15_tile", - # problem (temporary?) with huffingface "lllyasviel/control_v11u_sd15_tile", - # so for now replace "lllyasviel/control_v11f1e_sd15_tile", - "lllyasviel/control_v11e_sd15_shuffle", - "lllyasviel/control_v11e_sd15_ip2p", - "lllyasviel/control_v11f1e_sd15_tile", - ################################################# - # thibaud sd v2.1 models (ControlNet v1.0? or v1.1? - ################################################## - "thibaud/controlnet-sd21-openpose-diffusers", - "thibaud/controlnet-sd21-canny-diffusers", - "thibaud/controlnet-sd21-depth-diffusers", - "thibaud/controlnet-sd21-scribble-diffusers", - "thibaud/controlnet-sd21-hed-diffusers", - "thibaud/controlnet-sd21-zoedepth-diffusers", - "thibaud/controlnet-sd21-color-diffusers", - "thibaud/controlnet-sd21-openposev2-diffusers", - "thibaud/controlnet-sd21-lineart-diffusers", - "thibaud/controlnet-sd21-normalbae-diffusers", - "thibaud/controlnet-sd21-ade20k-diffusers", - ############################################## - # ControlNetMediaPipeface, ControlNet v1.1 - ############################################## - # ["CrucibleAI/ControlNetMediaPipeFace", "diffusion_sd15"], # SD 1.5 - # diffusion_sd15 needs to be passed to from_pretrained() as subfolder arg - # hacked t2l to split to model & subfolder if format is "model,subfolder" - "CrucibleAI/ControlNetMediaPipeFace,diffusion_sd15", # SD 1.5 - "CrucibleAI/ControlNetMediaPipeFace", # SD 2.1? -] -CONTROLNET_NAME_VALUES = Literal[tuple(CONTROLNET_DEFAULT_MODELS)] -CONTROLNET_MODE_VALUES = Literal[tuple(["balanced", "more_prompt", "more_control", "unbalanced"])] +CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] CONTROLNET_RESIZE_VALUES = Literal[ - tuple( - [ - "just_resize", - "crop_resize", - "fill_resize", - "just_resize_simple", - ] - ) + "just_resize", + "crop_resize", + "fill_resize", + "just_resize_simple", ] @@ -110,9 +60,8 @@ class ControlNetModelField(BaseModel): class ControlField(BaseModel): - image: ImageField = Field(default=None, description="The control image") - control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") - # control_weight: Optional[float] = Field(default=1, description="weight given to controlnet") + image: ImageField = Field(description="The control image") + control_model: ControlNetModelField = Field(description="The ControlNet model to use") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") begin_step_percent: float = Field( default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" @@ -135,60 +84,39 @@ class ControlField(BaseModel): raise ValueError("Control weights must be within -1 to 2 range") return v - class Config: - schema_extra = { - "required": ["image", "control_model", "control_weight", "begin_step_percent", "end_step_percent"], - "ui": { - "type_hints": { - "control_weight": "float", - "control_model": "controlnet_model", - # "control_weight": "number", - } - }, - } - class ControlOutput(BaseInvocationOutput): """node output for ControlNet info""" - # fmt: off type: Literal["control_output"] = "control_output" - control: ControlField = Field(default=None, description="The control info") - # fmt: on + + # Outputs + control: ControlField = OutputField(description=FieldDescriptions.control) +@title("ControlNet") +@tags("controlnet") class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" - # fmt: off type: Literal["controlnet"] = "controlnet" - # Inputs - image: ImageField = Field(default=None, description="The control image") - control_model: ControlNetModelField = Field(default="lllyasviel/sd-controlnet-canny", - description="control model used") - control_weight: Union[float, List[float]] = Field(default=1.0, description="The weight given to the ControlNet") - begin_step_percent: float = Field(default=0, ge=-1, le=2, - description="When the ControlNet is first applied (% of total steps)") - end_step_percent: float = Field(default=1, ge=0, le=1, - description="When the ControlNet is last applied (% of total steps)") - control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode used") - resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode used") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "ControlNet", - "tags": ["controlnet", "latents"], - "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number", - "control_weight": "float", - }, - }, - } + # Inputs + image: ImageField = InputField(description="The control image") + control_model: ControlNetModelField = InputField( + default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct + ) + control_weight: Union[float, List[float]] = InputField( + default=1.0, description="The weight given to the ControlNet", ui_type_hint=UITypeHint.Float + ) + begin_step_percent: float = InputField( + default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)" + ) + end_step_percent: float = InputField( + default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)" + ) + control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used") + resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used") def invoke(self, context: InvocationContext) -> ControlOutput: return ControlOutput( @@ -204,19 +132,13 @@ class ControlNetInvocation(BaseInvocation): ) -class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): +class ImageProcessorInvocation(BaseInvocation): """Base class for invocations that preprocess images for ControlNet""" - # fmt: off type: Literal["image_processor"] = "image_processor" - # Inputs - image: ImageField = Field(default=None, description="The image to process") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Processor", "tags": ["image", "processor"]}, - } + # Inputs + image: ImageField = InputField(description="The image to process") def run_processor(self, image): # superclass just passes through image without processing @@ -255,20 +177,20 @@ class ImageProcessorInvocation(BaseInvocation, PILInvocationConfig): ) -class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Canny Processor") +@tags("controlnet", "canny") +class CannyImageProcessorInvocation(ImageProcessorInvocation): """Canny edge detection for ControlNet""" - # fmt: off type: Literal["canny_image_processor"] = "canny_image_processor" - # Input - low_threshold: int = Field(default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)") - high_threshold: int = Field(default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Canny Processor", "tags": ["controlnet", "canny", "image", "processor"]}, - } + # Input + low_threshold: int = InputField( + default=100, ge=0, le=255, description="The low threshold of the Canny pixel gradient (0-255)" + ) + high_threshold: int = InputField( + default=200, ge=0, le=255, description="The high threshold of the Canny pixel gradient (0-255)" + ) def run_processor(self, image): canny_processor = CannyDetector() @@ -276,23 +198,19 @@ class CannyImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi return processed_image -class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("HED (softedge) Processor") +@tags("controlnet", "hed", "softedge") +class HedImageProcessorInvocation(ImageProcessorInvocation): """Applies HED edge detection to image""" - # fmt: off type: Literal["hed_image_processor"] = "hed_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # safe not supported in controlnet_aux v0.0.3 - # safe: bool = Field(default=False, description="whether to use safe mode") - scribble: bool = Field(default=False, description="Whether to use scribble mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Softedge(HED) Processor", "tags": ["controlnet", "softedge", "hed", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + # safe not supported in controlnet_aux v0.0.3 + # safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) + scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) def run_processor(self, image): hed_processor = HEDdetector.from_pretrained("lllyasviel/Annotators") @@ -307,21 +225,17 @@ class HedImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig) return processed_image -class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Lineart Processor") +@tags("controlnet", "lineart") +class LineartImageProcessorInvocation(ImageProcessorInvocation): """Applies line art processing to image""" - # fmt: off type: Literal["lineart_image_processor"] = "lineart_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - coarse: bool = Field(default=False, description="Whether to use coarse mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Lineart Processor", "tags": ["controlnet", "lineart", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + coarse: bool = InputField(default=False, description="Whether to use coarse mode") def run_processor(self, image): lineart_processor = LineartDetector.from_pretrained("lllyasviel/Annotators") @@ -331,23 +245,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCon return processed_image -class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Lineart Anime Processor") +@tags("controlnet", "lineart", "anime") +class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation): """Applies line art anime processing to image""" - # fmt: off type: Literal["lineart_anime_image_processor"] = "lineart_anime_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Lineart Anime Processor", - "tags": ["controlnet", "lineart", "anime", "image", "processor"], - }, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") @@ -359,21 +266,17 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation, PILInvocati return processed_image -class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Openpose Processor") +@tags("controlnet", "openpose", "pose") +class OpenposeImageProcessorInvocation(ImageProcessorInvocation): """Applies Openpose processing to image""" - # fmt: off type: Literal["openpose_image_processor"] = "openpose_image_processor" - # Inputs - hand_and_face: bool = Field(default=False, description="Whether to use hands and face mode") - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Openpose Processor", "tags": ["controlnet", "openpose", "image", "processor"]}, - } + # Inputs + hand_and_face: bool = InputField(default=False, description="Whether to use hands and face mode") + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): openpose_processor = OpenposeDetector.from_pretrained("lllyasviel/Annotators") @@ -386,22 +289,18 @@ class OpenposeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Midas (Depth) Processor") +@tags("controlnet", "midas", "depth") +class MidasDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Midas depth processing to image""" - # fmt: off type: Literal["midas_depth_image_processor"] = "midas_depth_image_processor" - # Inputs - a_mult: float = Field(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") - bg_th: float = Field(default=0.1, ge=0, description="Midas parameter `bg_th`") - # depth_and_normal not supported in controlnet_aux v0.0.3 - # depth_and_normal: bool = Field(default=False, description="whether to use depth and normal mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Midas (Depth) Processor", "tags": ["controlnet", "midas", "depth", "image", "processor"]}, - } + # Inputs + a_mult: float = InputField(default=2.0, ge=0, description="Midas parameter `a_mult` (a = a_mult * PI)") + bg_th: float = InputField(default=0.1, ge=0, description="Midas parameter `bg_th`") + # depth_and_normal not supported in controlnet_aux v0.0.3 + # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") def run_processor(self, image): midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") @@ -415,20 +314,16 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocation return processed_image -class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Normal BAE Processor") +@tags("controlnet", "normal", "bae") +class NormalbaeImageProcessorInvocation(ImageProcessorInvocation): """Applies NormalBae processing to image""" - # fmt: off type: Literal["normalbae_image_processor"] = "normalbae_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Normal BAE Processor", "tags": ["controlnet", "normal", "bae", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") @@ -438,22 +333,18 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation, PILInvocationC return processed_image -class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("MLSD Processor") +@tags("controlnet", "mlsd") +class MlsdImageProcessorInvocation(ImageProcessorInvocation): """Applies MLSD processing to image""" - # fmt: off type: Literal["mlsd_image_processor"] = "mlsd_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - thr_v: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_v`") - thr_d: float = Field(default=0.1, ge=0, description="MLSD parameter `thr_d`") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "MLSD Processor", "tags": ["controlnet", "mlsd", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + thr_v: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_v`") + thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") def run_processor(self, image): mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") @@ -467,22 +358,18 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig return processed_image -class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("PIDI Processor") +@tags("controlnet", "pidi") +class PidiImageProcessorInvocation(ImageProcessorInvocation): """Applies PIDI processing to image""" - # fmt: off type: Literal["pidi_image_processor"] = "pidi_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - safe: bool = Field(default=False, description="Whether to use safe mode") - scribble: bool = Field(default=False, description="Whether to use scribble mode") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "PIDI Processor", "tags": ["controlnet", "pidi", "image", "processor"]}, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + safe: bool = InputField(default=False, description=FieldDescriptions.safe_mode) + scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) def run_processor(self, image): pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") @@ -496,26 +383,19 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig return processed_image -class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Content Shuffle Processor") +@tags("controlnet", "contentshuffle") +class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation): """Applies content shuffle processing to image""" - # fmt: off type: Literal["content_shuffle_image_processor"] = "content_shuffle_image_processor" - # Inputs - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - h: Optional[int] = Field(default=512, ge=0, description="Content shuffle `h` parameter") - w: Optional[int] = Field(default=512, ge=0, description="Content shuffle `w` parameter") - f: Optional[int] = Field(default=256, ge=0, description="Content shuffle `f` parameter") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Content Shuffle Processor", - "tags": ["controlnet", "contentshuffle", "image", "processor"], - }, - } + # Inputs + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) + h: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `h` parameter") + w: Optional[int] = InputField(default=512, ge=0, description="Content shuffle `w` parameter") + f: Optional[int] = InputField(default=256, ge=0, description="Content shuffle `f` parameter") def run_processor(self, image): content_shuffle_processor = ContentShuffleDetector() @@ -531,17 +411,12 @@ class ContentShuffleImageProcessorInvocation(ImageProcessorInvocation, PILInvoca # should work with controlnet_aux >= 0.0.4 and timm <= 0.6.13 -class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Zoe (Depth) Processor") +@tags("controlnet", "zoe", "depth") +class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation): """Applies Zoe depth processing to image""" - # fmt: off type: Literal["zoe_depth_image_processor"] = "zoe_depth_image_processor" - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Zoe (Depth) Processor", "tags": ["controlnet", "zoe", "depth", "image", "processor"]}, - } def run_processor(self, image): zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") @@ -549,20 +424,16 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Mediapipe Face Processor") +@tags("controlnet", "mediapipe", "face") +class MediapipeFaceProcessorInvocation(ImageProcessorInvocation): """Applies mediapipe face processing to image""" - # fmt: off type: Literal["mediapipe_face_processor"] = "mediapipe_face_processor" - # Inputs - max_faces: int = Field(default=1, ge=1, description="Maximum number of faces to detect") - min_confidence: float = Field(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mediapipe Processor", "tags": ["controlnet", "mediapipe", "image", "processor"]}, - } + # Inputs + max_faces: int = InputField(default=1, ge=1, description="Maximum number of faces to detect") + min_confidence: float = InputField(default=0.5, ge=0, le=1, description="Minimum confidence for face detection") def run_processor(self, image): # MediaPipeFaceDetector throws an error if image has alpha channel @@ -574,23 +445,19 @@ class MediapipeFaceProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Leres (Depth) Processor") +@tags("controlnet", "leres", "depth") +class LeresImageProcessorInvocation(ImageProcessorInvocation): """Applies leres processing to image""" - # fmt: off type: Literal["leres_image_processor"] = "leres_image_processor" - # Inputs - thr_a: float = Field(default=0, description="Leres parameter `thr_a`") - thr_b: float = Field(default=0, description="Leres parameter `thr_b`") - boost: bool = Field(default=False, description="Whether to use boost mode") - detect_resolution: int = Field(default=512, ge=0, description="The pixel resolution for detection") - image_resolution: int = Field(default=512, ge=0, description="The pixel resolution for the output image") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Leres (Depth) Processor", "tags": ["controlnet", "leres", "depth", "image", "processor"]}, - } + # Inputs + thr_a: float = InputField(default=0, description="Leres parameter `thr_a`") + thr_b: float = InputField(default=0, description="Leres parameter `thr_b`") + boost: bool = InputField(default=False, description="Whether to use boost mode") + detect_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.detect_res) + image_resolution: int = InputField(default=512, ge=0, description=FieldDescriptions.image_res) def run_processor(self, image): leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") @@ -605,21 +472,16 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation, PILInvocationConfi return processed_image -class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): - # fmt: off - type: Literal["tile_image_processor"] = "tile_image_processor" - # Inputs - #res: int = Field(default=512, ge=0, le=1024, description="The pixel resolution for each tile") - down_sampling_rate: float = Field(default=1.0, ge=1.0, le=8.0, description="Down sampling rate") - # fmt: on +@title("Tile Resample Processor") +@tags("controlnet", "tile") +class TileResamplerProcessorInvocation(ImageProcessorInvocation): + """Tile resampler processor""" - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Tile Resample Processor", - "tags": ["controlnet", "tile", "resample", "image", "processor"], - }, - } + type: Literal["tile_image_processor"] = "tile_image_processor" + + # Inputs + # res: int = InputField(default=512, ge=0, le=1024, description="The pixel resolution for each tile") + down_sampling_rate: float = InputField(default=1.0, ge=1.0, le=8.0, description="Down sampling rate") # tile_resample copied from sd-webui-controlnet/scripts/processor.py def tile_resample( @@ -648,20 +510,12 @@ class TileResamplerProcessorInvocation(ImageProcessorInvocation, PILInvocationCo return processed_image -class SegmentAnythingProcessorInvocation(ImageProcessorInvocation, PILInvocationConfig): +@title("Segment Anything Processor") +@tags("controlnet", "segmentanything") +class SegmentAnythingProcessorInvocation(ImageProcessorInvocation): """Applies segment anything processing to image""" - # fmt: off type: Literal["segment_anything_processor"] = "segment_anything_processor" - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Segment Anything Processor", - "tags": ["controlnet", "segment", "anything", "sam", "image", "processor"], - }, - } def run_processor(self, image): # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index bd3a4adbe4..ed2030a835 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -5,40 +5,22 @@ from typing import Literal import cv2 as cv import numpy from PIL import Image, ImageOps -from pydantic import BaseModel, Field from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig +from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title from .image import ImageOutput -class CvInvocationConfig(BaseModel): - """Helper class to provide all OpenCV invocations with additional config""" - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["cv", "image"], - }, - } - - -class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): +@title("OpenCV Inpaint") +@tags("opencv", "inpaint") +class CvInpaintInvocation(BaseInvocation): """Simple inpaint using opencv.""" - # fmt: off type: Literal["cv_inpaint"] = "cv_inpaint" # Inputs - image: ImageField = Field(default=None, description="The image to inpaint") - mask: ImageField = Field(default=None, description="The mask to use when inpainting") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "OpenCV Inpaint", "tags": ["opencv", "inpaint"]}, - } + image: ImageField = InputField(description="The image to inpaint") + mask: ImageField = InputField(description="The mask to use when inpainting") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 2c47020207..5c277ec30f 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1,37 +1,30 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Optional import cv2 import numpy from PIL import Image, ImageChops, ImageFilter, ImageOps -from pydantic import Field from invokeai.app.invocations.metadata import CoreMetadata from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.safety_checker import SafetyChecker -from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, PILInvocationConfig, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext +from ..models.image import ImageCategory, ImageField, ImageOutput, MaskOutput, ResourceOrigin +from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, tags, title +@title("Load Image") +@tags("image") class LoadImageInvocation(BaseInvocation): """Load an image and provide it as output.""" - # fmt: off + # Metadata type: Literal["load_image"] = "load_image" # Inputs - image: Optional[ImageField] = Field( - default=None, description="The image to load" - ) - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Load Image", "tags": ["image", "load"]}, - } + image: ImageField = InputField(description="The image to load") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -43,18 +36,16 @@ class LoadImageInvocation(BaseInvocation): ) +@title("Show Image") +@tags("image") class ShowImageInvocation(BaseInvocation): """Displays a provided image, and passes it forward in the pipeline.""" + # Metadata type: Literal["show_image"] = "show_image" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to show") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Show Image", "tags": ["image", "show"]}, - } + image: ImageField = InputField(description="The image to show") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -70,24 +61,20 @@ class ShowImageInvocation(BaseInvocation): ) -class ImageCropInvocation(BaseInvocation, PILInvocationConfig): +@title("Crop Image") +@tags("image", "crop") +class ImageCropInvocation(BaseInvocation): """Crops an image to a specified box. The box can be outside of the image.""" - # fmt: off + # Metadata type: Literal["img_crop"] = "img_crop" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to crop") - x: int = Field(default=0, description="The left x coordinate of the crop rectangle") - y: int = Field(default=0, description="The top y coordinate of the crop rectangle") - width: int = Field(default=512, gt=0, description="The width of the crop rectangle") - height: int = Field(default=512, gt=0, description="The height of the crop rectangle") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Crop Image", "tags": ["image", "crop"]}, - } + image: ImageField = InputField(description="The image to crop") + x: int = InputField(default=0, description="The left x coordinate of the crop rectangle") + y: int = InputField(default=0, description="The top y coordinate of the crop rectangle") + width: int = InputField(default=512, gt=0, description="The width of the crop rectangle") + height: int = InputField(default=512, gt=0, description="The height of the crop rectangle") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -111,24 +98,23 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): ) -class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): +@title("Paste Image") +@tags("image", "paste") +class ImagePasteInvocation(BaseInvocation): """Pastes an image into another image.""" - # fmt: off + # Metadata type: Literal["img_paste"] = "img_paste" # Inputs - base_image: Optional[ImageField] = Field(default=None, description="The base image") - image: Optional[ImageField] = Field(default=None, description="The image to paste") - mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") - x: int = Field(default=0, description="The left x coordinate at which to paste the image") - y: int = Field(default=0, description="The top y coordinate at which to paste the image") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Paste Image", "tags": ["image", "paste"]}, - } + base_image: ImageField = InputField(description="The base image") + image: ImageField = InputField(description="The image to paste") + mask: Optional[ImageField] = InputField( + default=None, + description="The mask to use when pasting", + ) + x: int = InputField(default=0, description="The left x coordinate at which to paste the image") + y: int = InputField(default=0, description="The top y coordinate at which to paste the image") def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.services.images.get_pil_image(self.base_image.image_name) @@ -164,21 +150,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): ) -class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): +@title("Mask from Alpha") +@tags("image", "mask") +class MaskFromAlphaInvocation(BaseInvocation): """Extracts the alpha channel of an image as a mask.""" - # fmt: off + # Metadata type: Literal["tomask"] = "tomask" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to create the mask from") - invert: bool = Field(default=False, description="Whether or not to invert the mask") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mask From Alpha", "tags": ["image", "mask", "alpha"]}, - } + image: ImageField = InputField(description="The image to create the mask from") + invert: bool = InputField(default=False, description="Whether or not to invert the mask") def invoke(self, context: InvocationContext) -> MaskOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -203,21 +185,17 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): +@title("Multiply Images") +@tags("image", "multiply") +class ImageMultiplyInvocation(BaseInvocation): """Multiplies two images together using `PIL.ImageChops.multiply()`.""" - # fmt: off + # Metadata type: Literal["img_mul"] = "img_mul" # Inputs - image1: Optional[ImageField] = Field(default=None, description="The first image to multiply") - image2: Optional[ImageField] = Field(default=None, description="The second image to multiply") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Multiply Images", "tags": ["image", "multiply"]}, - } + image1: ImageField = InputField(description="The first image to multiply") + image2: ImageField = InputField(description="The second image to multiply") def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.services.images.get_pil_image(self.image1.image_name) @@ -244,21 +222,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): IMAGE_CHANNELS = Literal["A", "R", "G", "B"] -class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): +@title("Extract Image Channel") +@tags("image", "channel") +class ImageChannelInvocation(BaseInvocation): """Gets a channel from an image.""" - # fmt: off + # Metadata type: Literal["img_chan"] = "img_chan" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to get the channel from") - channel: IMAGE_CHANNELS = Field(default="A", description="The channel to get") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Channel", "tags": ["image", "channel"]}, - } + image: ImageField = InputField(description="The image to get the channel from") + channel: IMAGE_CHANNELS = InputField(default="A", description="The channel to get") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -284,21 +258,17 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] -class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): +@title("Convert Image Mode") +@tags("image", "convert") +class ImageConvertInvocation(BaseInvocation): """Converts an image to a different mode.""" - # fmt: off + # Metadata type: Literal["img_conv"] = "img_conv" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to convert") - mode: IMAGE_MODES = Field(default="L", description="The mode to convert to") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Convert Image", "tags": ["image", "convert"]}, - } + image: ImageField = InputField(description="The image to convert") + mode: IMAGE_MODES = InputField(default="L", description="The mode to convert to") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -321,22 +291,19 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): +@title("Blur Image") +@tags("image", "blur") +class ImageBlurInvocation(BaseInvocation): """Blurs an image""" - # fmt: off + # Metadata type: Literal["img_blur"] = "img_blur" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to blur") - radius: float = Field(default=8.0, ge=0, description="The blur radius") - blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Blur Image", "tags": ["image", "blur"]}, - } + image: ImageField = InputField(description="The image to blur") + radius: float = InputField(default=8.0, ge=0, description="The blur radius") + # Metadata + blur_type: Literal["gaussian", "box"] = InputField(default="gaussian", description="The type of blur") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -382,23 +349,19 @@ PIL_RESAMPLING_MAP = { } -class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): +@title("Resize Image") +@tags("image", "resize") +class ImageResizeInvocation(BaseInvocation): """Resizes an image to specific dimensions""" - # fmt: off + # Metadata type: Literal["img_resize"] = "img_resize" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to resize") - width: Union[int, None] = Field(ge=64, multiple_of=8, description="The width to resize to (px)") - height: Union[int, None] = Field(ge=64, multiple_of=8, description="The height to resize to (px)") - resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Resize Image", "tags": ["image", "resize"]}, - } + image: ImageField = InputField(description="The image to resize") + width: int = InputField(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") + height: int = InputField(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") + resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -426,22 +389,22 @@ class ImageResizeInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): +@title("Scale Image") +@tags("image", "scale") +class ImageScaleInvocation(BaseInvocation): """Scales an image by a factor""" - # fmt: off + # Metadata type: Literal["img_scale"] = "img_scale" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to scale") - scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image") - resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Scale Image", "tags": ["image", "scale"]}, - } + image: ImageField = InputField(description="The image to scale") + scale_factor: float = InputField( + default=2.0, + gt=0, + description="The factor by which to scale the image", + ) + resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -471,22 +434,18 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): +@title("Lerp Image") +@tags("image", "lerp") +class ImageLerpInvocation(BaseInvocation): """Linear interpolation of all pixels of an image""" - # fmt: off + # Metadata type: Literal["img_lerp"] = "img_lerp" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to lerp") - min: int = Field(default=0, ge=0, le=255, description="The minimum output value") - max: int = Field(default=255, ge=0, le=255, description="The maximum output value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image Linear Interpolation", "tags": ["image", "linear", "interpolation", "lerp"]}, - } + image: ImageField = InputField(description="The image to lerp") + min: int = InputField(default=0, ge=0, le=255, description="The minimum output value") + max: int = InputField(default=255, ge=0, le=255, description="The maximum output value") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -512,25 +471,18 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): +@title("Inverse Lerp Image") +@tags("image", "ilerp") +class ImageInverseLerpInvocation(BaseInvocation): """Inverse linear interpolation of all pixels of an image""" - # fmt: off + # Metadata type: Literal["img_ilerp"] = "img_ilerp" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to lerp") - min: int = Field(default=0, ge=0, le=255, description="The minimum input value") - max: int = Field(default=255, ge=0, le=255, description="The maximum input value") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Image Inverse Linear Interpolation", - "tags": ["image", "linear", "interpolation", "inverse"], - }, - } + image: ImageField = InputField(description="The image to lerp") + min: int = InputField(default=0, ge=0, le=255, description="The minimum input value") + max: int = InputField(default=255, ge=0, le=255, description="The maximum input value") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -556,21 +508,19 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): ) -class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): +@title("Blur NSFW Image") +@tags("image", "nsfw") +class ImageNSFWBlurInvocation(BaseInvocation): """Add blur to NSFW-flagged images""" - # fmt: off + # Metadata type: Literal["img_nsfw"] = "img_nsfw" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to check") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Blur NSFW Images", "tags": ["image", "nsfw", "checker"]}, - } + image: ImageField = InputField(description="The image to check") + metadata: Optional[CoreMetadata] = InputField( + default=None, description=FieldDescriptions.core_metadata, ui_hidden=True + ) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -607,22 +557,20 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): return caution.resize((caution.width // 2, caution.height // 2)) -class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): +@title("Add Invisible Watermark") +@tags("image", "watermark") +class ImageWatermarkInvocation(BaseInvocation): """Add an invisible watermark to an image""" - # fmt: off + # Metadata type: Literal["img_watermark"] = "img_watermark" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to check") - text: str = Field(default='InvokeAI', description="Watermark text") - metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Add Invisible Watermark", "tags": ["image", "watermark", "invisible"]}, - } + image: ImageField = InputField(description="The image to check") + text: str = InputField(default="InvokeAI", description="Watermark text") + metadata: Optional[CoreMetadata] = InputField( + default=None, description=FieldDescriptions.core_metadata, ui_hidden=True + ) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -644,19 +592,21 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): ) -class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig): +@title("Mask Edge") +@tags("image", "mask", "inpaint") +class MaskEdgeInvocation(BaseInvocation): """Applies an edge mask to an image""" - # fmt: off type: Literal["mask_edge"] = "mask_edge" # Inputs - image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to") - edge_size: int = Field(description="The size of the edge") - edge_blur: int = Field(description="The amount of blur on the edge") - low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection") - high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection") - # fmt: on + image: ImageField = InputField(description="The image to apply the mask to") + edge_size: int = InputField(description="The size of the edge") + edge_blur: int = InputField(description="The amount of blur on the edge") + low_threshold: int = InputField(description="First threshold for the hysteresis procedure in Canny edge detection") + high_threshold: int = InputField( + description="Second threshold for the hysteresis procedure in Canny edge detection" + ) def invoke(self, context: InvocationContext) -> MaskOutput: mask = context.services.images.get_pil_image(self.image.image_name) @@ -690,21 +640,16 @@ class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig): ) -class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): +@title("Combine Mask") +@tags("image", "mask", "multiply") +class MaskCombineInvocation(BaseInvocation): """Combine two masks together by multiplying them using `PIL.ImageChops.multiply()`.""" - # fmt: off type: Literal["mask_combine"] = "mask_combine" # Inputs - mask1: ImageField = Field(default=None, description="The first mask to combine") - mask2: ImageField = Field(default=None, description="The second image to combine") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Mask Combine", "tags": ["mask", "combine"]}, - } + mask1: ImageField = InputField(description="The first mask to combine") + mask2: ImageField = InputField(description="The second image to combine") def invoke(self, context: InvocationContext) -> ImageOutput: mask1 = context.services.images.get_pil_image(self.mask1.image_name).convert("L") @@ -728,7 +673,9 @@ class MaskCombineInvocation(BaseInvocation, PILInvocationConfig): ) -class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): +@title("Color Correct") +@tags("image", "color") +class ColorCorrectInvocation(BaseInvocation): """ Shifts the colors of a target image to match the reference image, optionally using a mask to only color-correct certain regions of the target image. @@ -736,10 +683,11 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): type: Literal["color_correct"] = "color_correct" - image: Optional[ImageField] = Field(default=None, description="The image to color-correct") - reference: Optional[ImageField] = Field(default=None, description="Reference image for color-correction") - mask: Optional[ImageField] = Field(default=None, description="Mask to use when applying color-correction") - mask_blur_radius: float = Field(default=8, description="Mask blur radius") + # Inputs + image: ImageField = InputField(description="The image to color-correct") + reference: ImageField = InputField(description="Reference image for color-correction") + mask: Optional[ImageField] = InputField(default=None, description="Mask to use when applying color-correction") + mask_blur_radius: float = InputField(default=8, description="Mask blur radius") def invoke(self, context: InvocationContext) -> ImageOutput: pil_init_mask = None @@ -833,16 +781,16 @@ class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig): ) +@title("Image Hue Adjustment") +@tags("image", "hue", "hsl") class ImageHueAdjustmentInvocation(BaseInvocation): """Adjusts the Hue of an image.""" - # fmt: off type: Literal["img_hue_adjust"] = "img_hue_adjust" # Inputs - image: ImageField = Field(default=None, description="The image to adjust") - hue: int = Field(default=0, description="The degrees by which to rotate the hue, 0-360") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + hue: int = InputField(default=0, description="The degrees by which to rotate the hue, 0-360") def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) @@ -877,16 +825,18 @@ class ImageHueAdjustmentInvocation(BaseInvocation): ) +@title("Image Luminosity Adjustment") +@tags("image", "luminosity", "hsl") class ImageLuminosityAdjustmentInvocation(BaseInvocation): """Adjusts the Luminosity (Value) of an image.""" - # fmt: off type: Literal["img_luminosity_adjust"] = "img_luminosity_adjust" # Inputs - image: ImageField = Field(default=None, description="The image to adjust") - luminosity: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + luminosity: float = InputField( + default=1.0, ge=0, le=1, description="The factor by which to adjust the luminosity (value)" + ) def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) @@ -925,16 +875,16 @@ class ImageLuminosityAdjustmentInvocation(BaseInvocation): ) +@title("Image Saturation Adjustment") +@tags("image", "saturation", "hsl") class ImageSaturationAdjustmentInvocation(BaseInvocation): """Adjusts the Saturation of an image.""" - # fmt: off type: Literal["img_saturation_adjust"] = "img_saturation_adjust" # Inputs - image: ImageField = Field(default=None, description="The image to adjust") - saturation: float = Field(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") - # fmt: on + image: ImageField = InputField(description="The image to adjust") + saturation: float = InputField(default=1.0, ge=0, le=1, description="The factor by which to adjust the saturation") def invoke(self, context: InvocationContext) -> ImageOutput: pil_image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index cd5b2f9a11..2294f806ca 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -5,18 +5,13 @@ from typing import Literal, Optional, get_args import numpy as np import math from PIL import Image, ImageOps -from pydantic import Field from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import ( - BaseInvocation, - InvocationConfig, - InvocationContext, -) +from .baseinvocation import BaseInvocation, InputField, InvocationContext, UITypeHint, title, tags def infill_methods() -> list[str]: @@ -114,21 +109,20 @@ def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] return si +@title("Solid Color Infill") +@tags("image", "inpaint") class InfillColorInvocation(BaseInvocation): """Infills transparent areas of an image with a solid color""" type: Literal["infill_rgba"] = "infill_rgba" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - color: ColorField = Field( + + # Inputs + image: ImageField = InputField(description="The image to infill") + color: ColorField = InputField( default=ColorField(r=127, g=127, b=127, a=255), description="The color to use to infill", ) - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Color Infill", "tags": ["image", "inpaint", "color", "infill"]}, - } - def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -153,25 +147,23 @@ class InfillColorInvocation(BaseInvocation): ) +@title("Tile Infill") +@tags("image", "inpaint") class InfillTileInvocation(BaseInvocation): """Infills transparent areas of an image with tiles of the image""" type: Literal["infill_tile"] = "infill_tile" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - tile_size: int = Field(default=32, ge=1, description="The tile size (px)") - seed: int = Field( + # Input + image: ImageField = InputField(description="The image to infill") + tile_size: int = InputField(default=32, ge=1, description="The tile size (px)") + seed: int = InputField( ge=0, le=SEED_MAX, description="The seed to use for tile generation (omit for random)", default_factory=get_random_seed, ) - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Tile Infill", "tags": ["image", "inpaint", "tile", "infill"]}, - } - def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) @@ -194,17 +186,15 @@ class InfillTileInvocation(BaseInvocation): ) +@title("PatchMatch Infill") +@tags("image", "inpaint") class InfillPatchMatchInvocation(BaseInvocation): """Infills transparent areas of an image using the PatchMatch algorithm""" type: Literal["infill_patchmatch"] = "infill_patchmatch" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Patch Match Infill", "tags": ["image", "inpaint", "patchmatch", "infill"]}, - } + # Inputs + image: ImageField = InputField(description="The image to infill") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index c66c9c6214..63cfd95394 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -13,7 +13,8 @@ from diffusers.models.attention_processor import ( LoRAXFormersAttnProcessor, XFormersAttnProcessor, ) -from diffusers.schedulers import DPMSolverSDEScheduler, SchedulerMixin as Scheduler +from diffusers.schedulers import DPMSolverSDEScheduler +from diffusers.schedulers import SchedulerMixin as Scheduler from pydantic import BaseModel, Field, validator from torchvision.transforms.functional import resize as tv_resize @@ -23,6 +24,7 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.backend.model_management.models import ModelType, SilenceWarnings from ...backend.model_management import BaseModelType, ModelPatcher +from ...backend.model_management.lora import ModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion.diffusers_pipeline import ( ConditioningData, @@ -32,9 +34,20 @@ from ...backend.stable_diffusion.diffusers_pipeline import ( ) from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP -from ...backend.util.devices import choose_precision, choose_torch_device, torch_dtype +from ...backend.util.devices import choose_precision, choose_torch_device from ..models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UITypeHint, + tags, + title, +) from .compel import ConditioningField from .controlnet_image_processors import ControlField from .image import ImageOutput @@ -46,8 +59,8 @@ DEFAULT_PRECISION = choose_precision(choose_torch_device()) class LatentsField(BaseModel): """A latents field used for passing latents between invocations""" - latents_name: Optional[str] = Field(default=None, description="The name of the latents") - seed: Optional[int] = Field(description="Seed used to generate this latents") + latents_name: str = Field(description="The name of the latents") + seed: Optional[int] = Field(default=None, description="Seed used to generate this latents") class Config: schema_extra = {"required": ["latents_name"]} @@ -56,14 +69,14 @@ class LatentsField(BaseModel): class LatentsOutput(BaseInvocationOutput): """Base class for invocations that output latents""" - # fmt: off type: Literal["latents_output"] = "latents_output" # Inputs - latents: LatentsField = Field(default=None, description="The output latents") - width: int = Field(description="The width of the latents in pixels") - height: int = Field(description="The height of the latents in pixels") - # fmt: on + latents: LatentsField = OutputField( + description=FieldDescriptions.latents, + ) + width: int = OutputField(description=FieldDescriptions.width) + height: int = OutputField(description=FieldDescriptions.height) def build_latents_output(latents_name: str, latents: torch.Tensor, seed: Optional[int]): @@ -111,30 +124,36 @@ def get_scheduler( return scheduler +@title("Denoise Latents") +@tags("latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l") class DenoiseLatentsInvocation(BaseInvocation): """Denoises noisy latents to decodable images""" type: Literal["denoise_latents"] = "denoise_latents" # Inputs - positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") - negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") - noise: Optional[LatentsField] = Field(description="The noise to use") - steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: Union[float, List[float]] = Field( - default=7.5, - ge=1, - description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", + positive_conditioning: ConditioningField = InputField( + description=FieldDescriptions.positive_cond, input=Input.Connection ) - denoising_start: float = Field(default=0.0, ge=0, le=1, description="") - denoising_end: float = Field(default=1.0, ge=0, le=1, description="") - scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use") - unet: UNetField = Field(default=None, description="UNet submodel") - control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - latents: Optional[LatentsField] = Field(description="The latents to use as a base image") - mask: Optional[ImageField] = Field( - None, - description="Mask", + negative_conditioning: ConditioningField = InputField( + description=FieldDescriptions.negative_cond, input=Input.Connection + ) + noise: Optional[LatentsField] = InputField(description=FieldDescriptions.noise, input=Input.Connection) + steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) + cfg_scale: Union[float, List[float]] = InputField( + default=7.5, ge=1, description=FieldDescriptions.cfg_scale, ui_type_hint=UITypeHint.Float + ) + denoising_start: float = InputField(default=0.0, ge=0, le=1, description=FieldDescriptions.denoising_start) + denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end) + scheduler: SAMPLER_NAME_VALUES = InputField(default="euler", description=FieldDescriptions.scheduler) + unet: UNetField = InputField(description=FieldDescriptions.unet, input=Input.Connection) + control: Union[ControlField, list[ControlField]] = InputField( + default=None, description=FieldDescriptions.control, input=Input.Connection + ) + latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) + mask: Optional[ImageField] = InputField( + default=None, + description=FieldDescriptions.mask, ) @validator("cfg_scale") @@ -149,20 +168,6 @@ class DenoiseLatentsInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Denoise Latents", - "tags": ["denoise", "latents"], - "type_hints": { - "model": "model", - "control": "control", - "cfg_scale": "number", - }, - }, - } - # TODO: pass this an emitter method or something? or a session for dispatching? def dispatch_progress( self, @@ -474,29 +479,29 @@ class DenoiseLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=result_latents, seed=seed) -# Latent to image +@title("Latents to Image") +@tags("latents", "image", "vae") class LatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" type: Literal["l2i"] = "l2i" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") - vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Decode latents by overlaping tiles (less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") - metadata: Optional[CoreMetadata] = Field( - default=None, description="Optional core metadata to be written to the image" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, + ) + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) + metadata: CoreMetadata = InputField( + default=None, + description=FieldDescriptions.core_metadata, + ui_hidden=True, ) - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Latents To Image", - "tags": ["latents", "image"], - }, - } @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: @@ -574,24 +579,30 @@ class LatentsToImageInvocation(BaseInvocation): LATENTS_INTERPOLATION_MODE = Literal["nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"] +@title("Resize Latents") +@tags("latents", "resize") class ResizeLatentsInvocation(BaseInvocation): """Resizes latents to explicit width/height (in pixels). Provided dimensions are floor-divided by 8.""" type: Literal["lresize"] = "lresize" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to resize") - width: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The width to resize to (px)") - height: Union[int, None] = Field(default=512, ge=64, multiple_of=8, description="The height to resize to (px)") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field( - default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, ) - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Resize Latents", "tags": ["latents", "resize"]}, - } + width: int = InputField( + ge=64, + multiple_of=8, + description=FieldDescriptions.width, + ) + height: int = InputField( + ge=64, + multiple_of=8, + description=FieldDescriptions.width, + ) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -616,23 +627,21 @@ class ResizeLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) +@title("Scale Latents") +@tags("latents", "resize") class ScaleLatentsInvocation(BaseInvocation): """Scales latents by a given factor.""" type: Literal["lscale"] = "lscale" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to scale") - scale_factor: float = Field(gt=0, description="The factor by which to scale the latents") - mode: LATENTS_INTERPOLATION_MODE = Field(default="bilinear", description="The interpolation mode") - antialias: bool = Field( - default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)" + latents: LatentsField = InputField( + description=FieldDescriptions.latents, + input=Input.Connection, ) - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Scale Latents", "tags": ["latents", "scale"]}, - } + scale_factor: float = InputField(gt=0, description=FieldDescriptions.scale_factor) + mode: LATENTS_INTERPOLATION_MODE = InputField(default="bilinear", description=FieldDescriptions.interp_mode) + antialias: bool = InputField(default=False, description=FieldDescriptions.torch_antialias) def invoke(self, context: InvocationContext) -> LatentsOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -658,22 +667,23 @@ class ScaleLatentsInvocation(BaseInvocation): return build_latents_output(latents_name=name, latents=resized_latents, seed=self.latents.seed) +@title("Image to Latents") +@tags("latents", "image", "vae") class ImageToLatentsInvocation(BaseInvocation): """Encodes an image into latents.""" type: Literal["i2l"] = "i2l" # Inputs - image: Optional[ImageField] = Field(description="The image to encode") - vae: VaeField = Field(default=None, description="Vae submodel") - tiled: bool = Field(default=False, description="Encode latents by overlaping tiles(less memory consumption)") - fp32: bool = Field(DEFAULT_PRECISION == "float32", description="Decode in full precision") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Image To Latents", "tags": ["latents", "image"]}, - } + image: ImageField = InputField( + description="The image to encode", + ) + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + tiled: bool = InputField(default=False, description=FieldDescriptions.tiled) + fp32: bool = InputField(default=DEFAULT_PRECISION == "float32", description=FieldDescriptions.fp32) @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: diff --git a/invokeai/app/invocations/math.py b/invokeai/app/invocations/math.py index 32b1ab2a39..81c032ca89 100644 --- a/invokeai/app/invocations/math.py +++ b/invokeai/app/invocations/math.py @@ -2,134 +2,104 @@ from typing import Literal -from pydantic import BaseModel, Field import numpy as np from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, + FieldDescriptions, + InputField, InvocationContext, - InvocationConfig, + OutputField, + tags, + title, ) -class MathInvocationConfig(BaseModel): - """Helper class to provide all math invocations with additional config""" - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["math"], - } - } - - class IntOutput(BaseInvocationOutput): """An integer output""" - # fmt: off type: Literal["int_output"] = "int_output" - a: int = Field(default=None, description="The output integer") - # fmt: on + a: int = OutputField(default=None, description="The output integer") class FloatOutput(BaseInvocationOutput): """A float output""" - # fmt: off type: Literal["float_output"] = "float_output" - param: float = Field(default=None, description="The output float") - # fmt: on + a: float = OutputField(default=None, description="The output float") -class AddInvocation(BaseInvocation, MathInvocationConfig): +@title("Add Integers") +@tags("math") +class AddInvocation(BaseInvocation): """Adds two numbers""" - # fmt: off type: Literal["add"] = "add" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Add", "tags": ["math", "add"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a + self.b) -class SubtractInvocation(BaseInvocation, MathInvocationConfig): +@title("Subtract Integers") +@tags("math") +class SubtractInvocation(BaseInvocation): """Subtracts two numbers""" - # fmt: off type: Literal["sub"] = "sub" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Subtract", "tags": ["math", "subtract"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a - self.b) -class MultiplyInvocation(BaseInvocation, MathInvocationConfig): +@title("Multiply Integers") +@tags("math") +class MultiplyInvocation(BaseInvocation): """Multiplies two numbers""" - # fmt: off type: Literal["mul"] = "mul" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Multiply", "tags": ["math", "multiply"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a * self.b) -class DivideInvocation(BaseInvocation, MathInvocationConfig): +@title("Divide Integers") +@tags("math") +class DivideInvocation(BaseInvocation): """Divides two numbers""" - # fmt: off type: Literal["div"] = "div" - a: int = Field(default=0, description="The first number") - b: int = Field(default=0, description="The second number") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Divide", "tags": ["math", "divide"]}, - } + # Inputs + a: int = InputField(default=0, description=FieldDescriptions.num_1) + b: int = InputField(default=0, description=FieldDescriptions.num_2) def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=int(self.a / self.b)) +@title("Random Integer") +@tags("math") class RandomIntInvocation(BaseInvocation): """Outputs a single random integer.""" - # fmt: off type: Literal["rand_int"] = "rand_int" - low: int = Field(default=0, description="The inclusive low value") - high: int = Field( - default=np.iinfo(np.int32).max, description="The exclusive high value" - ) - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Random Integer", "tags": ["math", "random", "integer"]}, - } + # Inputs + low: int = InputField(default=0, description="The inclusive low value") + high: int = InputField(default=np.iinfo(np.int32).max, description="The exclusive high value") def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=np.random.randint(self.low, self.high)) diff --git a/invokeai/app/invocations/metadata.py b/invokeai/app/invocations/metadata.py index d0549f8539..b0e7c13d43 100644 --- a/invokeai/app/invocations/metadata.py +++ b/invokeai/app/invocations/metadata.py @@ -1,18 +1,21 @@ -from typing import Literal, Optional, Union +from typing import Literal, Optional from pydantic import Field -from ...version import __version__ from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationConfig, + InputField, InvocationContext, + tags, + title, ) from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField from invokeai.app.util.model_exclude_null import BaseModelExcludeNull +from ...version import __version__ + class LoRAMetadataField(BaseModelExcludeNull): """LoRA metadata for an image generated in InvokeAI.""" @@ -43,37 +46,37 @@ class CoreMetadata(BaseModelExcludeNull): model: MainModelField = Field(description="The main model used for inference") controlnets: list[ControlField] = Field(description="The ControlNets used for inference") loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") - vae: Union[VAEModelField, None] = Field( + vae: Optional[VAEModelField] = Field( default=None, description="The VAE used for decoding, if the main model's default was not used", ) # Latents-to-Latents - strength: Union[float, None] = Field( + strength: Optional[float] = Field( default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field(default=None, description="The name of the initial image") + init_image: Optional[str] = Field(default=None, description="The name of the initial image") # SDXL - positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") - negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") + positive_style_prompt: Optional[str] = Field(default=None, description="The positive style prompt parameter") + negative_style_prompt: Optional[str] = Field(default=None, description="The negative style prompt parameter") # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") - refiner_cfg_scale: Union[float, None] = Field( + refiner_model: Optional[MainModelField] = Field(default=None, description="The SDXL Refiner model used") + refiner_cfg_scale: Optional[float] = Field( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") - refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") - refiner_positive_aesthetic_store: Union[float, None] = Field( + refiner_steps: Optional[int] = Field(default=None, description="The number of steps used for the refiner") + refiner_scheduler: Optional[str] = Field(default=None, description="The scheduler used for the refiner") + refiner_positive_aesthetic_store: Optional[float] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_negative_aesthetic_store: Union[float, None] = Field( + refiner_negative_aesthetic_store: Optional[float] = Field( default=None, description="The aesthetic score used for the refiner" ) - refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") + refiner_start: Optional[float] = Field(default=None, description="The start value used for refiner denoising") class ImageMetadata(BaseModelExcludeNull): @@ -94,66 +97,83 @@ class MetadataAccumulatorOutput(BaseInvocationOutput): metadata: CoreMetadata = Field(description="The core metadata for the image") +@title("Metadata Accumulator") +@tags("metadata") class MetadataAccumulatorInvocation(BaseInvocation): """Outputs a Core Metadata Object""" type: Literal["metadata_accumulator"] = "metadata_accumulator" - generation_mode: str = Field( + generation_mode: str = InputField( description="The generation mode that output this image", ) - positive_prompt: str = Field(description="The positive prompt parameter") - negative_prompt: str = Field(description="The negative prompt parameter") - width: int = Field(description="The width parameter") - height: int = Field(description="The height parameter") - seed: int = Field(description="The seed used for noise generation") - rand_device: str = Field(description="The device used for random number generation") - cfg_scale: float = Field(description="The classifier-free guidance scale parameter") - steps: int = Field(description="The number of steps used for inference") - scheduler: str = Field(description="The scheduler used for inference") - clip_skip: int = Field( + positive_prompt: str = InputField(description="The positive prompt parameter") + negative_prompt: str = InputField(description="The negative prompt parameter") + width: int = InputField(description="The width parameter") + height: int = InputField(description="The height parameter") + seed: int = InputField(description="The seed used for noise generation") + rand_device: str = InputField(description="The device used for random number generation") + cfg_scale: float = InputField(description="The classifier-free guidance scale parameter") + steps: int = InputField(description="The number of steps used for inference") + scheduler: str = InputField(description="The scheduler used for inference") + clip_skip: int = InputField( description="The number of skipped CLIP layers", ) - model: MainModelField = Field(description="The main model used for inference") - controlnets: list[ControlField] = Field(description="The ControlNets used for inference") - loras: list[LoRAMetadataField] = Field(description="The LoRAs used for inference") - strength: Union[float, None] = Field( + model: MainModelField = InputField(description="The main model used for inference") + controlnets: list[ControlField] = InputField(description="The ControlNets used for inference") + loras: list[LoRAMetadataField] = InputField(description="The LoRAs used for inference") + strength: Optional[float] = InputField( default=None, description="The strength used for latents-to-latents", ) - init_image: Union[str, None] = Field(default=None, description="The name of the initial image") - vae: Union[VAEModelField, None] = Field( + init_image: Optional[str] = InputField( + default=None, + description="The name of the initial image", + ) + vae: Optional[VAEModelField] = InputField( default=None, description="The VAE used for decoding, if the main model's default was not used", ) # SDXL - positive_style_prompt: Union[str, None] = Field(default=None, description="The positive style prompt parameter") - negative_style_prompt: Union[str, None] = Field(default=None, description="The negative style prompt parameter") + positive_style_prompt: Optional[str] = InputField( + default=None, + description="The positive style prompt parameter", + ) + negative_style_prompt: Optional[str] = InputField( + default=None, + description="The negative style prompt parameter", + ) # SDXL Refiner - refiner_model: Union[MainModelField, None] = Field(default=None, description="The SDXL Refiner model used") - refiner_cfg_scale: Union[float, None] = Field( + refiner_model: Optional[MainModelField] = InputField( + default=None, + description="The SDXL Refiner model used", + ) + refiner_cfg_scale: Optional[float] = InputField( default=None, description="The classifier-free guidance scale parameter used for the refiner", ) - refiner_steps: Union[int, None] = Field(default=None, description="The number of steps used for the refiner") - refiner_scheduler: Union[str, None] = Field(default=None, description="The scheduler used for the refiner") - refiner_positive_aesthetic_score: Union[float, None] = Field( - default=None, description="The aesthetic score used for the refiner" + refiner_steps: Optional[int] = InputField( + default=None, + description="The number of steps used for the refiner", ) - refiner_negative_aesthetic_score: Union[float, None] = Field( - default=None, description="The aesthetic score used for the refiner" + refiner_scheduler: Optional[str] = InputField( + default=None, + description="The scheduler used for the refiner", + ) + refiner_positive_aesthetic_store: Optional[float] = InputField( + default=None, + description="The aesthetic score used for the refiner", + ) + refiner_negative_aesthetic_store: Optional[float] = InputField( + default=None, + description="The aesthetic score used for the refiner", + ) + refiner_start: Optional[float] = InputField( + default=None, + description="The start value used for refiner denoising", ) - refiner_start: Union[float, None] = Field(default=None, description="The start value used for refiner denoising") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Metadata Accumulator", - "tags": ["image", "metadata", "generation"], - }, - } def invoke(self, context: InvocationContext) -> MetadataAccumulatorOutput: """Collects and outputs a CoreMetadata object""" diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 0d21f8f0ce..de32a9948f 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -4,7 +4,18 @@ from typing import List, Literal, Optional, Union from pydantic import BaseModel, Field from ...backend.model_management import BaseModelType, ModelType, SubModelType -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UITypeHint, + tags, + title, +) class ModelInfo(BaseModel): @@ -39,13 +50,11 @@ class VaeField(BaseModel): class ModelLoaderOutput(BaseInvocationOutput): """Model loader output""" - # fmt: off type: Literal["model_loader_output"] = "model_loader_output" - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") class MainModelField(BaseModel): @@ -63,24 +72,17 @@ class LoRAModelField(BaseModel): base_model: BaseModelType = Field(description="Base model") +@title("Main Model Loader") +@tags("model") class MainModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" type: Literal["main_model_loader"] = "main_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField(description=FieldDescriptions.main_model, input=Input.Direct) # TODO: precision? - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Model Loader", - "tags": ["model", "loader"], - "type_hints": {"model": "model"}, - }, - } - def invoke(self, context: InvocationContext) -> ModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name @@ -155,22 +157,6 @@ class MainModelLoaderInvocation(BaseInvocation): loras=[], skipped_layers=0, ), - clip2=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer2, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder2, - ), - loras=[], - skipped_layers=0, - ), vae=VaeField( vae=ModelInfo( model_name=model_name, @@ -188,30 +174,27 @@ class LoraLoaderOutput(BaseInvocationOutput): # fmt: off type: Literal["lora_loader_output"] = "lora_loader_output" - unet: Optional[UNetField] = Field(default=None, description="UNet submodel") - clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") + unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") # fmt: on +@title("LoRA Loader") +@tags("lora", "model") class LoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["lora_loader"] = "lora_loader" - lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") - weight: float = Field(default=0.75, description="With what weight to apply lora") - - unet: Optional[UNetField] = Field(description="UNet model for applying lora") - clip: Optional[ClipField] = Field(description="Clip model for applying lora") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Lora Loader", - "tags": ["lora", "loader"], - "type_hints": {"lora": "lora_model"}, - }, - } + # Inputs + lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) + unet: Optional[UNetField] = InputField( + default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNet" + ) + clip: Optional[ClipField] = InputField( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP" + ) def invoke(self, context: InvocationContext) -> LoraLoaderOutput: if self.lora is None: @@ -263,37 +246,35 @@ class LoraLoaderInvocation(BaseInvocation): class SDXLLoraLoaderOutput(BaseInvocationOutput): - """Model loader output""" + """SDXL LoRA Loader Output""" # fmt: off type: Literal["sdxl_lora_loader_output"] = "sdxl_lora_loader_output" - unet: Optional[UNetField] = Field(default=None, description="UNet submodel") - clip: Optional[ClipField] = Field(default=None, description="Tokenizer and text_encoder submodels") - clip2: Optional[ClipField] = Field(default=None, description="Tokenizer2 and text_encoder2 submodels") + unet: Optional[UNetField] = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 1") + clip2: Optional[ClipField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP 2") # fmt: on +@title("SDXL LoRA Loader") +@tags("sdxl", "lora", "model") class SDXLLoraLoaderInvocation(BaseInvocation): """Apply selected lora to unet and text_encoder.""" type: Literal["sdxl_lora_loader"] = "sdxl_lora_loader" - lora: Union[LoRAModelField, None] = Field(default=None, description="Lora model name") - weight: float = Field(default=0.75, description="With what weight to apply lora") - - unet: Optional[UNetField] = Field(description="UNet model for applying lora") - clip: Optional[ClipField] = Field(description="Clip model for applying lora") - clip2: Optional[ClipField] = Field(description="Clip2 model for applying lora") - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Lora Loader", - "tags": ["lora", "loader"], - "type_hints": {"lora": "lora_model"}, - }, - } + lora: LoRAModelField = InputField(description=FieldDescriptions.lora_model, input=Input.Direct, title="LoRA") + weight: float = Field(default=0.75, description=FieldDescriptions.lora_weight) + unet: Optional[UNetField] = Field( + default=None, description=FieldDescriptions.unet, input=Input.Connection, title="UNET" + ) + clip: Optional[ClipField] = Field( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 1" + ) + clip2: Optional[ClipField] = Field( + default=None, description=FieldDescriptions.clip, input=Input.Connection, title="CLIP 2" + ) def invoke(self, context: InvocationContext) -> SDXLLoraLoaderOutput: if self.lora is None: @@ -369,29 +350,23 @@ class VAEModelField(BaseModel): class VaeLoaderOutput(BaseInvocationOutput): """Model loader output""" - # fmt: off type: Literal["vae_loader_output"] = "vae_loader_output" - vae: VaeField = Field(default=None, description="Vae model") - # fmt: on + # Outputs + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") +@title("VAE Loader") +@tags("vae", "model") class VaeLoaderInvocation(BaseInvocation): """Loads a VAE model, outputting a VaeLoaderOutput""" type: Literal["vae_loader"] = "vae_loader" - vae_model: VAEModelField = Field(description="The VAE to load") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "VAE Loader", - "tags": ["vae", "loader"], - "type_hints": {"vae_model": "vae_model"}, - }, - } + # Inputs + vae_model: VAEModelField = InputField( + description=FieldDescriptions.vae_model, input=Input.Direct, ui_type_hint=UITypeHint.VaeModelField, title="VAE" + ) def invoke(self, context: InvocationContext) -> VaeLoaderOutput: base_model = self.vae_model.base_model diff --git a/invokeai/app/invocations/noise.py b/invokeai/app/invocations/noise.py index db64e5b6e5..7049dad61a 100644 --- a/invokeai/app/invocations/noise.py +++ b/invokeai/app/invocations/noise.py @@ -1,19 +1,24 @@ # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team -import math from typing import Literal -from pydantic import Field, validator import torch -from invokeai.app.invocations.latent import LatentsField +from pydantic import validator +from invokeai.app.invocations.latent import LatentsField from invokeai.app.util.misc import SEED_MAX, get_random_seed + from ...backend.util.devices import choose_torch_device, torch_dtype from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, - InvocationConfig, + FieldDescriptions, + InputField, InvocationContext, + OutputField, + UITypeHint, + tags, + title, ) """ @@ -61,14 +66,12 @@ Nodes class NoiseOutput(BaseInvocationOutput): """Invocation noise output""" - # fmt: off - type: Literal["noise_output"] = "noise_output" + type: Literal["noise_output"] = "noise_output" # Inputs - noise: LatentsField = Field(default=None, description="The output noise") - width: int = Field(description="The width of the noise in pixels") - height: int = Field(description="The height of the noise in pixels") - # fmt: on + noise: LatentsField = OutputField(default=None, description=FieldDescriptions.noise) + width: int = OutputField(description=FieldDescriptions.width) + height: int = OutputField(description=FieldDescriptions.height) def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): @@ -79,44 +82,37 @@ def build_noise_output(latents_name: str, latents: torch.Tensor, seed: int): ) +@title("Noise") +@tags("latents", "noise") class NoiseInvocation(BaseInvocation): """Generates latent noise.""" type: Literal["noise"] = "noise" # Inputs - seed: int = Field( + seed: int = InputField( ge=0, le=SEED_MAX, - description="The seed to use", + description=FieldDescriptions.seed, default_factory=get_random_seed, ) - width: int = Field( + width: int = InputField( default=512, multiple_of=8, gt=0, - description="The width of the resulting noise", + description=FieldDescriptions.width, ) - height: int = Field( + height: int = InputField( default=512, multiple_of=8, gt=0, - description="The height of the resulting noise", + description=FieldDescriptions.height, ) - use_cpu: bool = Field( + use_cpu: bool = InputField( default=True, description="Use CPU for noise generation (for reproducible results across platforms)", ) - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Noise", - "tags": ["latents", "noise"], - }, - } - @validator("seed", pre=True) def modulo_seed(cls, v): """Returns the seed modulo (SEED_MAX + 1) to ensure it is within the valid range.""" diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py index 4f04a4f023..cd73d35d78 100644 --- a/invokeai/app/invocations/onnx.py +++ b/invokeai/app/invocations/onnx.py @@ -1,37 +1,44 @@ # Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) +import inspect +import re from contextlib import ExitStack from typing import List, Literal, Optional, Union -import re -import inspect - -from pydantic import BaseModel, Field, validator -import torch import numpy as np +import torch from diffusers import ControlNetModel, DPMSolverMultistepScheduler from diffusers.image_processor import VaeImageProcessor from diffusers.schedulers import SchedulerMixin as Scheduler - -from ..models.image import ImageCategory, ImageField, ResourceOrigin -from ...backend.model_management import ONNXModelPatcher -from ...backend.util import choose_torch_device -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from .compel import ConditioningField -from .controlnet_image_processors import ControlField -from .image import ImageOutput -from .model import ModelInfo, UNetField, VaeField +from pydantic import BaseModel, Field, validator +from tqdm import tqdm from invokeai.app.invocations.metadata import CoreMetadata -from invokeai.backend import BaseModelType, ModelType, SubModelType from invokeai.app.util.step_callback import stable_diffusion_step_callback +from invokeai.backend import BaseModelType, ModelType, SubModelType + +from ...backend.model_management import ONNXModelPatcher from ...backend.stable_diffusion import PipelineIntermediateState - -from tqdm import tqdm -from .model import ClipField -from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES -from .compel import CompelOutput - +from ...backend.util import choose_torch_device +from ..models.image import ImageCategory, ImageField, ResourceOrigin +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + InputField, + Input, + InvocationContext, + OutputField, + UIComponent, + UITypeHint, + tags, + title, +) +from .compel import CompelOutput, ConditioningField +from .controlnet_image_processors import ControlField +from .image import ImageOutput +from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler +from .model import ClipField, ModelInfo, UNetField, VaeField ORT_TO_NP_TYPE = { "tensor(bool)": np.bool_, @@ -51,11 +58,13 @@ ORT_TO_NP_TYPE = { PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))] +@title("ONNX Prompt (Raw)") +@tags("onnx", "prompt") class ONNXPromptInvocation(BaseInvocation): type: Literal["prompt_onnx"] = "prompt_onnx" - prompt: str = Field(default="", description="Prompt") - clip: ClipField = Field(None, description="Clip to use") + prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) + clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) def invoke(self, context: InvocationContext) -> CompelOutput: tokenizer_info = context.services.model_manager.get_model( @@ -134,25 +143,48 @@ class ONNXPromptInvocation(BaseInvocation): # Text to image +@title("ONNX Text to Latents") +@tags("latents", "inference", "txt2img", "onnx") class ONNXTextToLatentsInvocation(BaseInvocation): """Generates latents from conditionings.""" type: Literal["t2l_onnx"] = "t2l_onnx" # Inputs - # fmt: off - positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") - negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") - noise: Optional[LatentsField] = Field(description="The noise to use") - steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") - cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) - scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) - precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents") - unet: UNetField = Field(default=None, description="UNet submodel") - control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") - # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") - # fmt: on + positive_conditioning: ConditioningField = InputField( + description=FieldDescriptions.positive_cond, + input=Input.Connection, + ) + negative_conditioning: ConditioningField = InputField( + description=FieldDescriptions.negative_cond, + input=Input.Connection, + ) + noise: LatentsField = InputField( + description=FieldDescriptions.noise, + input=Input.Connection, + ) + steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) + cfg_scale: Union[float, List[float]] = InputField( + default=7.5, + ge=1, + description=FieldDescriptions.cfg_scale, + ui_type_hint=UITypeHint.Float, + ) + scheduler: SAMPLER_NAME_VALUES = InputField( + default="euler", description=FieldDescriptions.scheduler, input=Input.Direct + ) + precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) + unet: UNetField = InputField( + description=FieldDescriptions.unet, + input=Input.Connection, + ) + control: Optional[Union[ControlField, list[ControlField]]] = InputField( + default=None, + description=FieldDescriptions.control, + ui_type_hint=UITypeHint.ControlField, + ) + # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) + # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") @validator("cfg_scale") def ge_one(cls, v): @@ -166,20 +198,6 @@ class ONNXTextToLatentsInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["latents"], - "type_hints": { - "model": "model", - "control": "control", - # "cfg_scale": "float", - "cfg_scale": "number", - }, - }, - } - # based on # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 def invoke(self, context: InvocationContext) -> LatentsOutput: @@ -300,26 +318,28 @@ class ONNXTextToLatentsInvocation(BaseInvocation): # Latent to image +@title("ONNX Latents to Image") +@tags("latents", "image", "vae", "onnx") class ONNXLatentsToImageInvocation(BaseInvocation): """Generates an image from latents.""" type: Literal["l2i_onnx"] = "l2i_onnx" # Inputs - latents: Optional[LatentsField] = Field(description="The latents to generate an image from") - vae: VaeField = Field(default=None, description="Vae submodel") - metadata: Optional[CoreMetadata] = Field( - default=None, description="Optional core metadata to be written to the image" + latents: LatentsField = InputField( + description=FieldDescriptions.denoised_latents, + input=Input.Connection, ) - # tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["latents", "image"], - }, - } + vae: VaeField = InputField( + description=FieldDescriptions.vae, + input=Input.Connection, + ) + metadata: Optional[CoreMetadata] = InputField( + default=None, + description=FieldDescriptions.core_metadata, + ui_hidden=True, + ) + # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") def invoke(self, context: InvocationContext) -> ImageOutput: latents = context.services.latents.get(self.latents.latents_name) @@ -373,89 +393,13 @@ class ONNXModelLoaderOutput(BaseInvocationOutput): # fmt: off type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx" - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae_decoder: VaeField = Field(default=None, description="Vae submodel") - vae_encoder: VaeField = Field(default=None, description="Vae submodel") + unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") + vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") + vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") # fmt: on -class ONNXSD1ModelLoaderInvocation(BaseInvocation): - """Loading submodels of selected model.""" - - type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx" - - model_name: str = Field(default="", description="Model to load") - # TODO: precision? - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["model", "loader"], "type_hints": {"model_name": "model"}}, # TODO: rename to model_name? - } - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - model_name = "stable-diffusion-v1-5" - base_model = BaseModelType.StableDiffusion1 - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=BaseModelType.StableDiffusion1, - model_type=ModelType.ONNX, - ): - raise Exception(f"Unkown model name: {model_name}!") - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.TextEncoder, - ), - loras=[], - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=ModelType.ONNX, - submodel=SubModelType.VaeEncoder, - ), - ), - ) - - class OnnxModelField(BaseModel): """Onnx model field""" @@ -464,22 +408,17 @@ class OnnxModelField(BaseModel): model_type: ModelType = Field(description="Model Type") +@title("ONNX Model Loader") +@tags("onnx", "model") class OnnxModelLoaderInvocation(BaseInvocation): """Loads a main model, outputting its submodels.""" type: Literal["onnx_model_loader"] = "onnx_model_loader" - model: OnnxModelField = Field(description="The model to load") - - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "Onnx Model Loader", - "tags": ["model", "loader"], - "type_hints": {"model": "model"}, - }, - } + # Inputs + model: OnnxModelField = InputField( + description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModelField + ) def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: base_model = self.model.base_model diff --git a/invokeai/app/invocations/param_easing.py b/invokeai/app/invocations/param_easing.py index f910e5379c..d35dba5df3 100644 --- a/invokeai/app/invocations/param_easing.py +++ b/invokeai/app/invocations/param_easing.py @@ -1,73 +1,63 @@ import io -from typing import Literal, Optional, Any +from typing import Literal, Optional -# from PIL.Image import Image -import PIL.Image -from matplotlib.ticker import MaxNLocator -from matplotlib.figure import Figure - -from pydantic import BaseModel, Field -import numpy as np import matplotlib.pyplot as plt +import numpy as np +import PIL.Image from easing_functions import ( - LinearInOut, - QuadEaseInOut, - QuadEaseIn, - QuadEaseOut, - CubicEaseInOut, - CubicEaseIn, - CubicEaseOut, - QuarticEaseInOut, - QuarticEaseIn, - QuarticEaseOut, - QuinticEaseInOut, - QuinticEaseIn, - QuinticEaseOut, - SineEaseInOut, - SineEaseIn, - SineEaseOut, - CircularEaseIn, - CircularEaseInOut, - CircularEaseOut, - ExponentialEaseInOut, - ExponentialEaseIn, - ExponentialEaseOut, - ElasticEaseIn, - ElasticEaseInOut, - ElasticEaseOut, BackEaseIn, BackEaseInOut, BackEaseOut, BounceEaseIn, BounceEaseInOut, BounceEaseOut, + CircularEaseIn, + CircularEaseInOut, + CircularEaseOut, + CubicEaseIn, + CubicEaseInOut, + CubicEaseOut, + ElasticEaseIn, + ElasticEaseInOut, + ElasticEaseOut, + ExponentialEaseIn, + ExponentialEaseInOut, + ExponentialEaseOut, + LinearInOut, + QuadEaseIn, + QuadEaseInOut, + QuadEaseOut, + QuarticEaseIn, + QuarticEaseInOut, + QuarticEaseOut, + QuinticEaseIn, + QuinticEaseInOut, + QuinticEaseOut, + SineEaseIn, + SineEaseInOut, + SineEaseOut, ) +from matplotlib.figure import Figure +from matplotlib.ticker import MaxNLocator +from pydantic import BaseModel, Field -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - InvocationConfig, -) from ...backend.util.logging import InvokeAILogger +from .baseinvocation import BaseInvocation, InputField, InvocationContext, tags, title from .collections import FloatCollectionOutput +@title("Float Range") +@tags("math", "range") class FloatLinearRangeInvocation(BaseInvocation): """Creates a range""" type: Literal["float_range"] = "float_range" # Inputs - start: float = Field(default=5, description="The first value of the range") - stop: float = Field(default=10, description="The last value of the range") - steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)") - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Linear Range (Float)", "tags": ["math", "float", "linear", "range"]}, - } + start: float = InputField(default=5, description="The first value of the range") + stop: float = InputField(default=10, description="The last value of the range") + steps: int = InputField(default=30, description="number of values to interpolate over (including start and stop)") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: param_list = list(np.linspace(self.start, self.stop, self.steps)) @@ -108,37 +98,32 @@ EASING_FUNCTIONS_MAP = { "BounceInOut": BounceEaseInOut, } -EASING_FUNCTION_KEYS: Any = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] +EASING_FUNCTION_KEYS = Literal[tuple(list(EASING_FUNCTIONS_MAP.keys()))] # actually I think for now could just use CollectionOutput (which is list[Any] +@title("Step Param Easing") +@tags("step", "easing") class StepParamEasingInvocation(BaseInvocation): """Experimental per-step parameter easing for denoising steps""" type: Literal["step_param_easing"] = "step_param_easing" # Inputs - # fmt: off - easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use") - num_steps: int = Field(default=20, description="number of denoising steps") - start_value: float = Field(default=0.0, description="easing starting value") - end_value: float = Field(default=1.0, description="easing ending value") - start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing") - end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing") + easing: EASING_FUNCTION_KEYS = InputField(default="Linear", description="The easing function to use") + num_steps: int = InputField(default=20, description="number of denoising steps") + start_value: float = InputField(default=0.0, description="easing starting value") + end_value: float = InputField(default=1.0, description="easing ending value") + start_step_percent: float = InputField(default=0.0, description="fraction of steps at which to start easing") + end_step_percent: float = InputField(default=1.0, description="fraction of steps after which to end easing") # if None, then start_value is used prior to easing start - pre_start_value: Optional[float] = Field(default=None, description="value before easing start") + pre_start_value: Optional[float] = InputField(default=None, description="value before easing start") # if None, then end value is used prior to easing end - post_end_value: Optional[float] = Field(default=None, description="value after easing end") - mirror: bool = Field(default=False, description="include mirror of easing function") + post_end_value: Optional[float] = InputField(default=None, description="value after easing end") + mirror: bool = InputField(default=False, description="include mirror of easing function") # FIXME: add alt_mirror option (alternative to default or mirror), or remove entirely - # alt_mirror: bool = Field(default=False, description="alternative mirroring by dual easing") - show_easing_plot: bool = Field(default=False, description="show easing plot") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Param Easing By Step", "tags": ["param", "step", "easing"]}, - } + # alt_mirror: bool = InputField(default=False, description="alternative mirroring by dual easing") + show_easing_plot: bool = InputField(default=False, description="show easing plot") def invoke(self, context: InvocationContext) -> FloatCollectionOutput: log_diagnostics = False diff --git a/invokeai/app/invocations/params.py b/invokeai/app/invocations/params.py index 513eb8762f..27382d8f8d 100644 --- a/invokeai/app/invocations/params.py +++ b/invokeai/app/invocations/params.py @@ -2,82 +2,80 @@ from typing import Literal -from pydantic import Field - from invokeai.app.invocations.prompt import PromptOutput -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InputField, + InvocationContext, + OutputField, + tags, + title, +) from .math import FloatOutput, IntOutput # Pass-through parameter nodes - used by subgraphs +@title("Integer Parameter") +@tags("integer") class ParamIntInvocation(BaseInvocation): """An integer parameter""" - # fmt: off type: Literal["param_int"] = "param_int" - a: int = Field(default=0, description="The integer value") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "integer"], "title": "Integer Parameter"}, - } + # Inputs + a: int = InputField(default=0, description="The integer value") def invoke(self, context: InvocationContext) -> IntOutput: return IntOutput(a=self.a) +@title("Float Parameter") +@tags("float") class ParamFloatInvocation(BaseInvocation): """A float parameter""" - # fmt: off type: Literal["param_float"] = "param_float" - param: float = Field(default=0.0, description="The float value") - # fmt: on - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "float"], "title": "Float Parameter"}, - } + # Inputs + param: float = InputField(default=0.0, description="The float value") def invoke(self, context: InvocationContext) -> FloatOutput: - return FloatOutput(param=self.param) + return FloatOutput(a=self.param) class StringOutput(BaseInvocationOutput): """A string output""" type: Literal["string_output"] = "string_output" - text: str = Field(default=None, description="The output string") + text: str = OutputField(description="The output string") +@title("String Parameter") +@tags("string") class ParamStringInvocation(BaseInvocation): """A string parameter""" type: Literal["param_string"] = "param_string" - text: str = Field(default="", description="The string value") - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "string"], "title": "String Parameter"}, - } + # Inputs + text: str = InputField(default="", description="The string value") def invoke(self, context: InvocationContext) -> StringOutput: return StringOutput(text=self.text) +@title("Prompt Parameter") +@tags("prompt") class ParamPromptInvocation(BaseInvocation): """A prompt input parameter""" type: Literal["param_prompt"] = "param_prompt" - prompt: str = Field(default="", description="The prompt value") - class Config(InvocationConfig): - schema_extra = { - "ui": {"tags": ["param", "prompt"], "title": "Prompt"}, - } + # Inputs + prompt: str = InputField(default="", description="The prompt value") def invoke(self, context: InvocationContext) -> PromptOutput: return PromptOutput(prompt=self.prompt) diff --git a/invokeai/app/invocations/prompt.py b/invokeai/app/invocations/prompt.py index 83a397ddcf..57320c695c 100644 --- a/invokeai/app/invocations/prompt.py +++ b/invokeai/app/invocations/prompt.py @@ -2,56 +2,52 @@ from os.path import exists from typing import Literal, Optional import numpy as np -from pydantic import Field, validator +from pydantic import validator -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + InputField, + InvocationContext, + OutputField, + UIComponent, + UITypeHint, + title, + tags, +) from dynamicprompts.generators import RandomPromptGenerator, CombinatorialPromptGenerator class PromptOutput(BaseInvocationOutput): """Base class for invocations that output a prompt""" - # fmt: off type: Literal["prompt"] = "prompt" - prompt: str = Field(default=None, description="The output prompt") - # fmt: on - - class Config: - schema_extra = { - "required": [ - "type", - "prompt", - ] - } + prompt: str = OutputField(description="The output prompt") class PromptCollectionOutput(BaseInvocationOutput): """Base class for invocations that output a collection of prompts""" - # fmt: off type: Literal["prompt_collection_output"] = "prompt_collection_output" - prompt_collection: list[str] = Field(description="The output prompt collection") - count: int = Field(description="The size of the prompt collection") - # fmt: on - - class Config: - schema_extra = {"required": ["type", "prompt_collection", "count"]} + prompt_collection: list[str] = OutputField( + description="The output prompt collection", ui_type_hint=UITypeHint.StringCollection + ) + count: int = OutputField(description="The size of the prompt collection") +@title("Dynamic Prompt") +@tags("prompt", "collection") class DynamicPromptInvocation(BaseInvocation): """Parses a prompt using adieyal/dynamicprompts' random or combinatorial generator""" type: Literal["dynamic_prompt"] = "dynamic_prompt" - prompt: str = Field(description="The prompt to parse with dynamicprompts") - max_prompts: int = Field(default=1, description="The number of prompts to generate") - combinatorial: bool = Field(default=False, description="Whether to use the combinatorial generator") - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Dynamic Prompt", "tags": ["prompt", "dynamic"]}, - } + # Inputs + prompt: str = InputField(description="The prompt to parse with dynamicprompts", ui_component=UIComponent.Textarea) + max_prompts: int = InputField(default=1, description="The number of prompts to generate") + combinatorial: bool = InputField(default=False, description="Whether to use the combinatorial generator") def invoke(self, context: InvocationContext) -> PromptCollectionOutput: if self.combinatorial: @@ -64,24 +60,23 @@ class DynamicPromptInvocation(BaseInvocation): return PromptCollectionOutput(prompt_collection=prompts, count=len(prompts)) +@title("Prompts from File") +@tags("prompt", "file") class PromptsFromFileInvocation(BaseInvocation): """Loads prompts from a text file""" - # fmt: off - type: Literal['prompt_from_file'] = 'prompt_from_file' + type: Literal["prompt_from_file"] = "prompt_from_file" # Inputs - file_path: str = Field(description="Path to prompt text file") - pre_prompt: Optional[str] = Field(description="String to prepend to each prompt") - post_prompt: Optional[str] = Field(description="String to append to each prompt") - start_line: int = Field(default=1, ge=1, description="Line in the file to start start from") - max_prompts: int = Field(default=1, ge=0, description="Max lines to read from file (0=all)") - # fmt: on - - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Prompts From File", "tags": ["prompt", "file"]}, - } + file_path: str = InputField(description="Path to prompt text file", ui_type_hint=UITypeHint.FilePath) + pre_prompt: Optional[str] = InputField( + description="String to prepend to each prompt", ui_component=UIComponent.Textarea + ) + post_prompt: Optional[str] = InputField( + description="String to append to each prompt", ui_component=UIComponent.Textarea + ) + start_line: int = InputField(default=1, ge=1, description="Line in the file to start start from") + max_prompts: int = InputField(default=1, ge=0, description="Max lines to read from file (0=all)") @validator("file_path") def file_path_exists(cls, v): diff --git a/invokeai/app/invocations/sdxl.py b/invokeai/app/invocations/sdxl.py index a5a1c2c641..d25a37327e 100644 --- a/invokeai/app/invocations/sdxl.py +++ b/invokeai/app/invocations/sdxl.py @@ -1,55 +1,55 @@ -import torch from typing import Literal -from pydantic import Field from ...backend.model_management import ModelType, SubModelType -from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext -from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo +from .baseinvocation import ( + BaseInvocation, + BaseInvocationOutput, + FieldDescriptions, + Input, + InputField, + InvocationContext, + OutputField, + UITypeHint, + tags, + title, +) +from .model import ClipField, MainModelField, ModelInfo, UNetField, VaeField class SDXLModelLoaderOutput(BaseInvocationOutput): """SDXL base model loader output""" - # fmt: off type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output" - unet: UNetField = Field(default=None, description="UNet submodel") - clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 1") + clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") class SDXLRefinerModelLoaderOutput(BaseInvocationOutput): """SDXL refiner model loader output""" - # fmt: off type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" - unet: UNetField = Field(default=None, description="UNet submodel") - clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") - vae: VaeField = Field(default=None, description="Vae submodel") - # fmt: on - # fmt: on + + unet: UNetField = OutputField(description=FieldDescriptions.unet, title="UNet") + clip2: ClipField = OutputField(description=FieldDescriptions.clip, title="CLIP 2") + vae: VaeField = OutputField(description=FieldDescriptions.vae, title="VAE") +@title("SDXL Main Model Loader") +@tags("model", "sdxl") class SDXLModelLoaderInvocation(BaseInvocation): """Loads an sdxl base model, outputting its submodels.""" type: Literal["sdxl_model_loader"] = "sdxl_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField( + description=FieldDescriptions.sdxl_main_model, input=Input.Direct, ui_type_hint=UITypeHint.SDXLMainModelField + ) # TODO: precision? - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Model Loader", - "tags": ["model", "loader", "sdxl"], - "type_hints": {"model": "model"}, - }, - } - def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name @@ -122,24 +122,21 @@ class SDXLModelLoaderInvocation(BaseInvocation): ) +@title("SDXL Refiner Model Loader") +@tags("model", "sdxl", "refiner") class SDXLRefinerModelLoaderInvocation(BaseInvocation): """Loads an sdxl refiner model, outputting its submodels.""" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" - model: MainModelField = Field(description="The model to load") + # Inputs + model: MainModelField = InputField( + description=FieldDescriptions.sdxl_refiner_model, + input=Input.Direct, + ui_type_hint=UITypeHint.SDXLRefinerModelField, + ) # TODO: precision? - # Schema customisation - class Config(InvocationConfig): - schema_extra = { - "ui": { - "title": "SDXL Refiner Model Loader", - "tags": ["model", "loader", "sdxl_refiner"], - "type_hints": {"model": "refiner_model"}, - }, - } - def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput: base_model = self.model.base_model model_name = self.model.model_name diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index fd220223db..4e9c9fac2f 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -6,12 +6,11 @@ import cv2 as cv import numpy as np from basicsr.archs.rrdbnet_arch import RRDBNet from PIL import Image -from pydantic import Field from realesrgan import RealESRGANer from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin -from .baseinvocation import BaseInvocation, InvocationConfig, InvocationContext +from .baseinvocation import BaseInvocation, InputField, InvocationContext, title, tags from .image import ImageOutput # TODO: Populate this from disk? @@ -24,17 +23,16 @@ ESRGAN_MODELS = Literal[ ] +@title("Upscale (RealESRGAN)") +@tags("esrgan", "upscale") class ESRGANInvocation(BaseInvocation): """Upscales an image using RealESRGAN.""" type: Literal["esrgan"] = "esrgan" - image: Union[ImageField, None] = Field(default=None, description="The input image") - model_name: ESRGAN_MODELS = Field(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") - class Config(InvocationConfig): - schema_extra = { - "ui": {"title": "Upscale (RealESRGAN)", "tags": ["image", "upscale", "realesrgan"]}, - } + # Inputs + image: ImageField = InputField(description="The input image") + model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 2a5a0f9d3b..3a34a7d6da 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -5,14 +5,13 @@ from pydantic import BaseModel, Field from invokeai.app.util.metaenum import MetaEnum from ..invocations.baseinvocation import ( BaseInvocationOutput, - InvocationConfig, ) class ImageField(BaseModel): """An image field used for passing image objects between invocations""" - image_name: Optional[str] = Field(default=None, description="The name of the image") + image_name: str = Field(description="The name of the image") class Config: schema_extra = {"required": ["image_name"]} @@ -36,17 +35,6 @@ class ProgressImage(BaseModel): dataURL: str = Field(description="The image data as a b64 data URL") -class PILInvocationConfig(BaseModel): - """Helper class to provide all PIL invocations with additional config""" - - class Config(InvocationConfig): - schema_extra = { - "ui": { - "tags": ["PIL", "image"], - }, - } - - class ImageOutput(BaseInvocationOutput): """Base class for invocations that output an image""" diff --git a/invokeai/app/services/graph.py b/invokeai/app/services/graph.py index d7f021df14..5dacfe1ec1 100644 --- a/invokeai/app/services/graph.py +++ b/invokeai/app/services/graph.py @@ -3,16 +3,7 @@ import copy import itertools import uuid -from typing import ( - Annotated, - Any, - Literal, - Optional, - Union, - get_args, - get_origin, - get_type_hints, -) +from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin, get_type_hints import networkx as nx from pydantic import BaseModel, root_validator, validator @@ -22,7 +13,11 @@ from ..invocations import * from ..invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Input, + InputField, InvocationContext, + OutputField, + UITypeHint, ) # in 3.10 this would be "from types import NoneType" @@ -183,15 +178,9 @@ class IterateInvocationOutput(BaseInvocationOutput): type: Literal["iterate_output"] = "iterate_output" - item: Any = Field(description="The item being iterated over") - - class Config: - schema_extra = { - "required": [ - "type", - "item", - ] - } + item: Any = OutputField( + description="The item being iterated over", title="Collection Item", ui_type_hint=UITypeHint.CollectionItem + ) # TODO: Fill this out and move to invocations @@ -200,8 +189,10 @@ class IterateInvocation(BaseInvocation): type: Literal["iterate"] = "iterate" - collection: list[Any] = Field(description="The list of items to iterate over", default_factory=list) - index: int = Field(description="The index, will be provided on executed iterators", default=0) + collection: list[Any] = InputField( + description="The list of items to iterate over", default_factory=list, ui_type_hint=UITypeHint.Collection + ) + index: int = InputField(description="The index, will be provided on executed iterators", default=0, ui_hidden=True) def invoke(self, context: InvocationContext) -> IterateInvocationOutput: """Produces the outputs as values""" @@ -211,15 +202,9 @@ class IterateInvocation(BaseInvocation): class CollectInvocationOutput(BaseInvocationOutput): type: Literal["collect_output"] = "collect_output" - collection: list[Any] = Field(description="The collection of input items") - - class Config: - schema_extra = { - "required": [ - "type", - "collection", - ] - } + collection: list[Any] = OutputField( + description="The collection of input items", title="Collection", ui_type_hint=UITypeHint.Collection + ) class CollectInvocation(BaseInvocation): @@ -227,13 +212,14 @@ class CollectInvocation(BaseInvocation): type: Literal["collect"] = "collect" - item: Any = Field( + item: Any = InputField( description="The item to collect (all inputs must be of the same type)", - default=None, + ui_type_hint=UITypeHint.CollectionItem, + title="Collection Item", + input=Input.Connection, ) - collection: list[Any] = Field( - description="The collection, will be provided on execution", - default_factory=list, + collection: list[Any] = InputField( + description="The collection, will be provided on execution", default_factory=list, ui_hidden=True ) def invoke(self, context: InvocationContext) -> CollectInvocationOutput: diff --git a/invokeai/app/services/processor.py b/invokeai/app/services/processor.py index 41170a304b..b8c2f93e93 100644 --- a/invokeai/app/services/processor.py +++ b/invokeai/app/services/processor.py @@ -87,7 +87,10 @@ class DefaultInvocationProcessor(InvocationProcessorABC): # Invoke try: with statistics.collect_stats(invocation, graph_execution_state.id): - outputs = invocation.invoke( + # use the internal invoke_internal(), which wraps the node's invoke() method in + # this accomodates nodes which require a value, but get it only from a + # connection + outputs = invocation.invoke_internal( InvocationContext( services=self.__invoker.services, graph_execution_state_id=graph_execution_state.id, diff --git a/invokeai/app/services/sqlite.py b/invokeai/app/services/sqlite.py index 855f3f1939..3c46b1c2a0 100644 --- a/invokeai/app/services/sqlite.py +++ b/invokeai/app/services/sqlite.py @@ -49,7 +49,8 @@ class SqliteItemStorage(ItemStorageABC, Generic[T]): def _parse_item(self, item: str) -> T: item_type = get_args(self.__orig_class__)[0] - return parse_raw_as(item_type, item) + parsed = parse_raw_as(item_type, item) + return parsed def set(self, item: T): try: diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 8cc2c158be..6c9db74bbc 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -61,6 +61,7 @@ "@dagrejs/graphlib": "^2.1.13", "@dnd-kit/core": "^6.0.8", "@dnd-kit/modifiers": "^6.0.1", + "@dnd-kit/utilities": "^3.2.1", "@emotion/react": "^11.11.1", "@emotion/styled": "^11.11.0", "@floating-ui/react-dom": "^2.0.1", diff --git a/invokeai/frontend/web/scripts/colors.js b/invokeai/frontend/web/scripts/colors.js new file mode 100644 index 0000000000..3fc8f8d751 --- /dev/null +++ b/invokeai/frontend/web/scripts/colors.js @@ -0,0 +1,34 @@ +export const COLORS = { + reset: '\x1b[0m', + bright: '\x1b[1m', + dim: '\x1b[2m', + underscore: '\x1b[4m', + blink: '\x1b[5m', + reverse: '\x1b[7m', + hidden: '\x1b[8m', + + fg: { + black: '\x1b[30m', + red: '\x1b[31m', + green: '\x1b[32m', + yellow: '\x1b[33m', + blue: '\x1b[34m', + magenta: '\x1b[35m', + cyan: '\x1b[36m', + white: '\x1b[37m', + gray: '\x1b[90m', + crimson: '\x1b[38m', + }, + bg: { + black: '\x1b[40m', + red: '\x1b[41m', + green: '\x1b[42m', + yellow: '\x1b[43m', + blue: '\x1b[44m', + magenta: '\x1b[45m', + cyan: '\x1b[46m', + white: '\x1b[47m', + gray: '\x1b[100m', + crimson: '\x1b[48m', + }, +}; diff --git a/invokeai/frontend/web/scripts/typegen.js b/invokeai/frontend/web/scripts/typegen.js index ec67c48f2d..d105917e66 100644 --- a/invokeai/frontend/web/scripts/typegen.js +++ b/invokeai/frontend/web/scripts/typegen.js @@ -1,23 +1,83 @@ import fs from 'node:fs'; import openapiTS from 'openapi-typescript'; +import { COLORS } from './colors.js'; const OPENAPI_URL = 'http://127.0.0.1:9090/openapi.json'; const OUTPUT_FILE = 'src/services/api/schema.d.ts'; async function main() { process.stdout.write( - `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...` + `Generating types "${OPENAPI_URL}" --> "${OUTPUT_FILE}"...\n\n` ); const types = await openapiTS(OPENAPI_URL, { exportType: true, - transform: (schemaObject) => { + transform: (schemaObject, metadata) => { if ('format' in schemaObject && schemaObject.format === 'binary') { return schemaObject.nullable ? 'Blob | null' : 'Blob'; } + + /** + * Because invocations may have required fields that accept connection input, the generated + * types may be incorrect. + * + * For example, the ImageResizeInvocation has a required `image` field, but because it accepts + * connection input, it should be optional on instantiation of the field. + * + * To handle this, the schema exposes an `input` property that can be used to determine if the + * field accepts connection input. If it does, we can make the field optional. + */ + + // Check if we are generating types for an invocation + const isInvocationPath = metadata.path.match( + /^#\/components\/schemas\/\w*Invocation$/ + ); + + const hasInvocationProperties = + schemaObject.properties && + ['id', 'is_intermediate', 'type'].every( + (prop) => prop in schemaObject.properties + ); + + if (isInvocationPath && hasInvocationProperties) { + // We only want to make fields optional if they are required + if (!Array.isArray(schemaObject?.required)) { + schemaObject.required = ['id', 'type']; + return; + } + + schemaObject.required.forEach((prop) => { + const acceptsConnection = ['any', 'connection'].includes( + schemaObject.properties?.[prop]?.['input'] + ); + + if (acceptsConnection) { + // remove this prop from the required array + const invocationName = metadata.path.split('/').pop(); + console.log( + `Making connectable field optional: ${COLORS.fg.green}${invocationName}.${COLORS.fg.cyan}${prop}${COLORS.reset}` + ); + schemaObject.required = schemaObject.required.filter( + (r) => r !== prop + ); + } + }); + + schemaObject.required = [ + ...new Set(schemaObject.required.concat(['id', 'type'])), + ]; + + return; + } + // if ( + // 'input' in schemaObject && + // (schemaObject.input === 'any' || schemaObject.input === 'connection') + // ) { + // schemaObject.required = false; + // } }, }); fs.writeFileSync(OUTPUT_FILE, types); - process.stdout.write(` OK!\r\n`); + process.stdout.write(`\nOK!\r\n`); } main(); diff --git a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts index 9827e7f2b3..bbe77dc698 100644 --- a/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts +++ b/invokeai/frontend/web/src/app/components/GlobalHotkeys.ts @@ -1,8 +1,12 @@ import { createSelector } from '@reduxjs/toolkit'; -import { RootState } from 'app/store/store'; +import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { requestCanvasRescale } from 'features/canvas/store/thunks/requestCanvasScale'; -import { shiftKeyPressed } from 'features/ui/store/hotkeysSlice'; +import { + ctrlKeyPressed, + metaKeyPressed, + shiftKeyPressed, +} from 'features/ui/store/hotkeysSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import { setActiveTab, @@ -16,11 +20,11 @@ import React, { memo } from 'react'; import { isHotkeyPressed, useHotkeys } from 'react-hotkeys-hook'; const globalHotkeysSelector = createSelector( - [(state: RootState) => state.hotkeys, (state: RootState) => state.ui], - (hotkeys, ui) => { - const { shift } = hotkeys; + [stateSelector], + ({ hotkeys, ui }) => { + const { shift, ctrl, meta } = hotkeys; const { shouldPinParametersPanel, shouldPinGallery } = ui; - return { shift, shouldPinGallery, shouldPinParametersPanel }; + return { shift, ctrl, meta, shouldPinGallery, shouldPinParametersPanel }; }, { memoizeOptions: { @@ -37,9 +41,8 @@ const globalHotkeysSelector = createSelector( */ const GlobalHotkeys: React.FC = () => { const dispatch = useAppDispatch(); - const { shift, shouldPinParametersPanel, shouldPinGallery } = useAppSelector( - globalHotkeysSelector - ); + const { shift, ctrl, meta, shouldPinParametersPanel, shouldPinGallery } = + useAppSelector(globalHotkeysSelector); const activeTabName = useAppSelector(activeTabNameSelector); useHotkeys( @@ -50,9 +53,19 @@ const GlobalHotkeys: React.FC = () => { } else { shift && dispatch(shiftKeyPressed(false)); } + if (isHotkeyPressed('ctrl')) { + !ctrl && dispatch(ctrlKeyPressed(true)); + } else { + ctrl && dispatch(ctrlKeyPressed(false)); + } + if (isHotkeyPressed('meta')) { + !meta && dispatch(metaKeyPressed(true)); + } else { + meta && dispatch(metaKeyPressed(false)); + } }, { keyup: true, keydown: true }, - [shift] + [shift, ctrl, meta] ); useHotkeys('o', () => { diff --git a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx index 93b7825db7..7e2ed7f571 100644 --- a/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx +++ b/invokeai/frontend/web/src/app/components/InvokeAIUI.tsx @@ -14,7 +14,7 @@ import { $authToken, $baseUrl, $projectId } from 'services/api/client'; import { socketMiddleware } from 'services/events/middleware'; import Loading from '../../common/components/Loading/Loading'; import '../../i18n'; -import ImageDndContext from './ImageDnd/ImageDndContext'; +import AppDndContext from '../../features/dnd/components/AppDndContext'; const App = lazy(() => import('./App')); const ThemeLocaleProvider = lazy(() => import('./ThemeLocaleProvider')); @@ -80,9 +80,9 @@ const InvokeAIUI = ({ }> - + - + diff --git a/invokeai/frontend/web/src/app/logging/logger.ts b/invokeai/frontend/web/src/app/logging/logger.ts index ef27c98d1f..7797b8dc92 100644 --- a/invokeai/frontend/web/src/app/logging/logger.ts +++ b/invokeai/frontend/web/src/app/logging/logger.ts @@ -19,7 +19,8 @@ type LoggerNamespace = | 'nodes' | 'system' | 'socketio' - | 'session'; + | 'session' + | 'dnd'; export const logger = (namespace: LoggerNamespace) => $logger.get().child({ namespace }); diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts index 6d41d488c8..a596fce931 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionsDenylist.ts @@ -15,7 +15,7 @@ export const actionsDenylist = [ 'socket/socketGeneratorProgress', 'socket/appSocketGeneratorProgress', // every time user presses shift - 'hotkeys/shiftKeyPressed', + // 'hotkeys/shiftKeyPressed', // this happens after every state change '@@REMEMBER_PERSISTED', ]; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts index 043105cb66..fc0b44653d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageDropped.ts @@ -1,16 +1,20 @@ import { createAction } from '@reduxjs/toolkit'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { + fieldImageValueChanged, + workflowExposedFieldAdded, +} from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { imagesApi } from 'services/api/endpoints/images'; import { startAppListening } from '../'; +import { parseify } from 'common/util/serialize'; export const dndDropped = createAction<{ overData: TypesafeDroppableData; @@ -21,7 +25,7 @@ export const addImageDroppedListener = () => { startAppListening({ actionCreator: dndDropped, effect: async (action, { dispatch }) => { - const log = logger('images'); + const log = logger('dnd'); const { activeData, overData } = action.payload; if (activeData.payloadType === 'IMAGE_DTO') { @@ -31,10 +35,28 @@ export const addImageDroppedListener = () => { { activeData, overData }, `Images (${activeData.payload.imageDTOs.length}) dropped` ); + } else if (activeData.payloadType === 'NODE_FIELD') { + log.debug( + { activeData: parseify(activeData), overData: parseify(overData) }, + 'Node field dropped' + ); } else { log.debug({ activeData, overData }, `Unknown payload dropped`); } + if ( + overData.actionType === 'ADD_FIELD_TO_LINEAR' && + activeData.payloadType === 'NODE_FIELD' + ) { + const { nodeId, field } = activeData.payload; + dispatch( + workflowExposedFieldAdded({ + nodeId, + fieldName: field.name, + }) + ); + } + /** * Image dropped on current image */ @@ -99,7 +121,7 @@ export const addImageDroppedListener = () => { ) { const { fieldName, nodeId } = overData.context; dispatch( - fieldValueChanged({ + fieldImageValueChanged({ nodeId, fieldName, value: activeData.payload.imageDTO, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts index 6dc2d482a9..0c55908748 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/imageUploaded.ts @@ -2,7 +2,7 @@ import { UseToastOptions } from '@chakra-ui/react'; import { logger } from 'app/logging/logger'; import { setInitialCanvasImage } from 'features/canvas/store/canvasSlice'; import { controlNetImageChanged } from 'features/controlNet/store/controlNetSlice'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; +import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; import { initialImageChanged } from 'features/parameters/store/generationSlice'; import { addToast } from 'features/system/store/systemSlice'; import { omit } from 'lodash-es'; @@ -111,7 +111,9 @@ export const addImageUploadedFulfilledListener = () => { if (postUploadAction?.type === 'SET_NODES_IMAGE') { const { nodeId, fieldName } = postUploadAction; - dispatch(fieldValueChanged({ nodeId, fieldName, value: imageDTO })); + dispatch( + fieldImageValueChanged({ nodeId, fieldName, value: imageDTO }) + ); dispatch( addToast({ ...DEFAULT_UPLOADED_TOAST, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts index 436a58aa8e..4d30ee3b8b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/modelsLoaded.ts @@ -15,12 +15,21 @@ import { setShouldUseSDXLRefiner, } from 'features/sdxl/store/sdxlSlice'; import { forEach, some } from 'lodash-es'; -import { modelsApi, vaeModelsAdapter } from 'services/api/endpoints/models'; +import { + mainModelsAdapter, + modelsApi, + vaeModelsAdapter, +} from 'services/api/endpoints/models'; +import { TypeGuardFor } from 'services/api/types'; import { startAppListening } from '..'; export const addModelsLoadedListener = () => { startAppListening({ - predicate: (state, action) => + predicate: ( + action + ): action is TypeGuardFor< + typeof modelsApi.endpoints.getMainModels.matchFulfilled + > => modelsApi.endpoints.getMainModels.matchFulfilled(action) && !action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { @@ -32,29 +41,28 @@ export const addModelsLoadedListener = () => { ); const currentModel = getState().generation.model; + const models = mainModelsAdapter.getSelectors().selectAll(action.payload); - const isCurrentModelAvailable = some( - action.payload.entities, - (m) => - m?.model_name === currentModel?.model_name && - m?.base_model === currentModel?.base_model && - m?.model_type === currentModel?.model_type - ); - - if (isCurrentModelAvailable) { - return; - } - - const firstModelId = action.payload.ids[0]; - const firstModel = action.payload.entities[firstModelId]; - - if (!firstModel) { + if (models.length === 0) { // No models loaded at all dispatch(modelChanged(null)); return; } - const result = zMainOrOnnxModel.safeParse(firstModel); + const isCurrentModelAvailable = currentModel + ? models.some( + (m) => + m.model_name === currentModel.model_name && + m.base_model === currentModel.base_model && + m.model_type === currentModel.model_type + ) + : false; + + if (isCurrentModelAvailable) { + return; + } + + const result = zMainOrOnnxModel.safeParse(models[0]); if (!result.success) { log.error( @@ -68,7 +76,11 @@ export const addModelsLoadedListener = () => { }, }); startAppListening({ - predicate: (state, action) => + predicate: ( + action + ): action is TypeGuardFor< + typeof modelsApi.endpoints.getMainModels.matchFulfilled + > => modelsApi.endpoints.getMainModels.matchFulfilled(action) && action.meta.arg.originalArgs.includes('sdxl-refiner'), effect: async (action, { getState, dispatch }) => { @@ -80,30 +92,29 @@ export const addModelsLoadedListener = () => { ); const currentModel = getState().sdxl.refinerModel; + const models = mainModelsAdapter.getSelectors().selectAll(action.payload); - const isCurrentModelAvailable = some( - action.payload.entities, - (m) => - m?.model_name === currentModel?.model_name && - m?.base_model === currentModel?.base_model && - m?.model_type === currentModel?.model_type - ); - - if (isCurrentModelAvailable) { - return; - } - - const firstModelId = action.payload.ids[0]; - const firstModel = action.payload.entities[firstModelId]; - - if (!firstModel) { + if (models.length === 0) { // No models loaded at all dispatch(refinerModelChanged(null)); dispatch(setShouldUseSDXLRefiner(false)); return; } - const result = zSDXLRefinerModel.safeParse(firstModel); + const isCurrentModelAvailable = currentModel + ? models.some( + (m) => + m.model_name === currentModel.model_name && + m.base_model === currentModel.base_model && + m.model_type === currentModel.model_type + ) + : false; + + if (isCurrentModelAvailable) { + return; + } + + const result = zSDXLRefinerModel.safeParse(models[0]); if (!result.success) { log.error( diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index 44729f215a..dd86c77735 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -13,7 +13,7 @@ export const addReceivedOpenAPISchemaListener = () => { const log = logger('system'); const schemaJSON = action.payload; - log.debug({ schemaJSON }, 'Dereferenced OpenAPI schema'); + log.debug({ schemaJSON }, 'Received OpenAPI schema'); const nodeTemplates = parseSchema(schemaJSON); @@ -28,9 +28,12 @@ export const addReceivedOpenAPISchemaListener = () => { startAppListening({ actionCreator: receivedOpenAPISchema.rejected, - effect: () => { + effect: (action) => { const log = logger('system'); - log.error('Problem dereferencing OpenAPI Schema'); + log.error( + { error: parseify(action.error) }, + 'Problem retrieving OpenAPI Schema' + ); }, }); }; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts index 5b3b9424b6..5501f208fd 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketInvocationComplete.ts @@ -19,7 +19,7 @@ import { } from 'services/events/actions'; import { startAppListening } from '../..'; -const nodeDenylist = ['dataURL_image']; +const nodeDenylist = ['load_image']; export const addInvocationCompleteEventListener = () => { startAppListening({ diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts index 0c298cbb24..5894bba5df 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/userInvokedNodes.ts @@ -15,7 +15,7 @@ export const addUserInvokedNodesListener = () => { const log = logger('session'); const state = getState(); - const graph = buildNodesGraph(state); + const graph = buildNodesGraph(state.nodes); dispatch(nodesGraphBuilt(graph)); log.debug({ graph: parseify(graph) }, 'Nodes graph built'); diff --git a/invokeai/frontend/web/src/app/types/invokeai.ts b/invokeai/frontend/web/src/app/types/invokeai.ts index 827424fa7f..a39ed2ca7b 100644 --- a/invokeai/frontend/web/src/app/types/invokeai.ts +++ b/invokeai/frontend/web/src/app/types/invokeai.ts @@ -1,86 +1,7 @@ -import { - // CONTROLNET_MODELS, - CONTROLNET_PROCESSORS, -} from 'features/controlNet/store/constants'; +import { CONTROLNET_PROCESSORS } from 'features/controlNet/store/constants'; import { InvokeTabName } from 'features/ui/store/tabMap'; import { O } from 'ts-toolbelt'; -// These are old types from the model management UI - -// export type ModelStatus = 'active' | 'cached' | 'not loaded'; - -// export type Model = { -// status: ModelStatus; -// description: string; -// weights: string; -// config?: string; -// vae?: string; -// width?: number; -// height?: number; -// default?: boolean; -// format?: string; -// }; - -// export type DiffusersModel = { -// status: ModelStatus; -// description: string; -// repo_id?: string; -// path?: string; -// vae?: { -// repo_id?: string; -// path?: string; -// }; -// format?: string; -// default?: boolean; -// }; - -// export type ModelList = Record; - -// export type FoundModel = { -// name: string; -// location: string; -// }; - -// export type InvokeModelConfigProps = { -// name: string | undefined; -// description: string | undefined; -// config: string | undefined; -// weights: string | undefined; -// vae: string | undefined; -// width: number | undefined; -// height: number | undefined; -// default: boolean | undefined; -// format: string | undefined; -// }; - -// export type InvokeDiffusersModelConfigProps = { -// name: string | undefined; -// description: string | undefined; -// repo_id: string | undefined; -// path: string | undefined; -// default: boolean | undefined; -// format: string | undefined; -// vae: { -// repo_id: string | undefined; -// path: string | undefined; -// }; -// }; - -// export type InvokeModelConversionProps = { -// model_name: string; -// save_location: string; -// custom_location: string | null; -// }; - -// export type InvokeModelMergingProps = { -// models_to_merge: string[]; -// alpha: number; -// interp: 'weighted_sum' | 'sigmoid' | 'inv_sigmoid' | 'add_difference'; -// force: boolean; -// merged_model_name: string; -// model_merge_save_path: string | null; -// }; - /** * A disable-able application feature */ diff --git a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx index 780447aba6..defe600b78 100644 --- a/invokeai/frontend/web/src/common/components/IAIDndImage.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDndImage.tsx @@ -6,10 +6,6 @@ import { useColorMode, useColorModeValue, } from '@chakra-ui/react'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; import IAIIconButton from 'common/components/IAIIconButton'; import { IAILoadingImageFallback, @@ -17,6 +13,10 @@ import { } from 'common/components/IAIImageFallback'; import ImageMetadataOverlay from 'common/components/ImageMetadataOverlay'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import ImageContextMenu from 'features/gallery/components/ImageContextMenu/ImageContextMenu'; import { MouseEvent, @@ -157,11 +157,10 @@ const IAIDndImage = (props: IAIDndImageProps) => { ) } - width={imageDTO.width} - height={imageDTO.height} onError={onError} draggable={false} sx={{ + w: imageDTO.width, objectFit: 'contain', maxW: 'full', maxH: 'full', @@ -213,13 +212,6 @@ const IAIDndImage = (props: IAIDndImageProps) => { onClick={onClick} /> )} - {!isDropDisabled && ( - - )} {onClickReset && withResetIcon && imageDTO && ( { }} /> )} + {!isDropDisabled && ( + + )} )} diff --git a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx index 482a8ac604..363799a573 100644 --- a/invokeai/frontend/web/src/common/components/IAIDraggable.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDraggable.tsx @@ -1,22 +1,19 @@ -import { Box } from '@chakra-ui/react'; -import { - TypesafeDraggableData, - useDraggable, -} from 'app/components/ImageDnd/typesafeDnd'; -import { MouseEvent, memo, useRef } from 'react'; +import { Box, BoxProps } from '@chakra-ui/react'; +import { useDraggableTypesafe } from 'features/dnd/hooks/typesafeHooks'; +import { TypesafeDraggableData } from 'features/dnd/types'; +import { memo, useRef } from 'react'; import { v4 as uuidv4 } from 'uuid'; -type IAIDraggableProps = { +type IAIDraggableProps = BoxProps & { disabled?: boolean; data?: TypesafeDraggableData; - onClick?: (event: MouseEvent) => void; }; const IAIDraggable = (props: IAIDraggableProps) => { - const { data, disabled, onClick } = props; + const { data, disabled, ...rest } = props; const dndId = useRef(uuidv4()); - const { attributes, listeners, setNodeRef } = useDraggable({ + const { attributes, listeners, setNodeRef } = useDraggableTypesafe({ id: dndId.current, disabled, data, @@ -24,7 +21,6 @@ const IAIDraggable = (props: IAIDraggableProps) => { return ( { insetInlineStart={0} {...attributes} {...listeners} + {...rest} /> ); }; diff --git a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx index 1038f36840..e4fb121c78 100644 --- a/invokeai/frontend/web/src/common/components/IAIDroppable.tsx +++ b/invokeai/frontend/web/src/common/components/IAIDroppable.tsx @@ -1,9 +1,7 @@ import { Box } from '@chakra-ui/react'; -import { - TypesafeDroppableData, - isValidDrop, - useDroppable, -} from 'app/components/ImageDnd/typesafeDnd'; +import { useDroppableTypesafe } from 'features/dnd/hooks/typesafeHooks'; +import { TypesafeDroppableData } from 'features/dnd/types'; +import { isValidDrop } from 'features/dnd/util/isValidDrop'; import { AnimatePresence } from 'framer-motion'; import { ReactNode, memo, useRef } from 'react'; import { v4 as uuidv4 } from 'uuid'; @@ -19,7 +17,7 @@ const IAIDroppable = (props: IAIDroppableProps) => { const { dropLabel, data, disabled } = props; const dndId = useRef(uuidv4()); - const { isOver, setNodeRef, active } = useDroppable({ + const { isOver, setNodeRef, active } = useDroppableTypesafe({ id: dndId.current, disabled, data, diff --git a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx index 2057525b7a..a150e4ed0c 100644 --- a/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx +++ b/invokeai/frontend/web/src/common/components/IAIImageFallback.tsx @@ -49,7 +49,7 @@ export const IAILoadingImageFallback = (props: Props) => { type IAINoImageFallbackProps = { label?: string; - icon?: As; + icon?: As | null; boxSize?: StyleProps['boxSize']; sx?: ChakraProps['sx']; }; @@ -76,7 +76,7 @@ export const IAINoContentFallback = (props: IAINoImageFallbackProps) => { ...props.sx, }} > - + {icon && } {props.label && {props.label}} ); diff --git a/invokeai/frontend/web/src/common/components/IAISwitch.tsx b/invokeai/frontend/web/src/common/components/IAISwitch.tsx index 9803626397..da0883d77e 100644 --- a/invokeai/frontend/web/src/common/components/IAISwitch.tsx +++ b/invokeai/frontend/web/src/common/components/IAISwitch.tsx @@ -1,10 +1,13 @@ import { + Flex, FormControl, FormControlProps, + FormHelperText, FormLabel, FormLabelProps, Switch, SwitchProps, + Text, Tooltip, } from '@chakra-ui/react'; import { memo } from 'react'; @@ -15,6 +18,7 @@ export interface IAISwitchProps extends SwitchProps { formControlProps?: FormControlProps; formLabelProps?: FormLabelProps; tooltip?: string; + helperText?: string; } /** @@ -28,6 +32,7 @@ const IAISwitch = (props: IAISwitchProps) => { formControlProps, formLabelProps, tooltip, + helperText, ...rest } = props; return ( @@ -35,25 +40,33 @@ const IAISwitch = (props: IAISwitchProps) => { - {label && ( - - {label} - - )} - + + + {label && ( + + {label} + + )} + + + {helperText && ( + + {helperText} + + )} + ); diff --git a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts index 770add7253..0afb7e7e5d 100644 --- a/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts +++ b/invokeai/frontend/web/src/common/hooks/useChakraThemeTokens.ts @@ -40,6 +40,44 @@ export const useChakraThemeTokens = () => { accent850, accent900, accent950, + baseAlpha50, + baseAlpha100, + baseAlpha150, + baseAlpha200, + baseAlpha250, + baseAlpha300, + baseAlpha350, + baseAlpha400, + baseAlpha450, + baseAlpha500, + baseAlpha550, + baseAlpha600, + baseAlpha650, + baseAlpha700, + baseAlpha750, + baseAlpha800, + baseAlpha850, + baseAlpha900, + baseAlpha950, + accentAlpha50, + accentAlpha100, + accentAlpha150, + accentAlpha200, + accentAlpha250, + accentAlpha300, + accentAlpha350, + accentAlpha400, + accentAlpha450, + accentAlpha500, + accentAlpha550, + accentAlpha600, + accentAlpha650, + accentAlpha700, + accentAlpha750, + accentAlpha800, + accentAlpha850, + accentAlpha900, + accentAlpha950, ] = useToken('colors', [ 'base.50', 'base.100', @@ -79,6 +117,44 @@ export const useChakraThemeTokens = () => { 'accent.850', 'accent.900', 'accent.950', + 'baseAlpha.50', + 'baseAlpha.100', + 'baseAlpha.150', + 'baseAlpha.200', + 'baseAlpha.250', + 'baseAlpha.300', + 'baseAlpha.350', + 'baseAlpha.400', + 'baseAlpha.450', + 'baseAlpha.500', + 'baseAlpha.550', + 'baseAlpha.600', + 'baseAlpha.650', + 'baseAlpha.700', + 'baseAlpha.750', + 'baseAlpha.800', + 'baseAlpha.850', + 'baseAlpha.900', + 'baseAlpha.950', + 'accentAlpha.50', + 'accentAlpha.100', + 'accentAlpha.150', + 'accentAlpha.200', + 'accentAlpha.250', + 'accentAlpha.300', + 'accentAlpha.350', + 'accentAlpha.400', + 'accentAlpha.450', + 'accentAlpha.500', + 'accentAlpha.550', + 'accentAlpha.600', + 'accentAlpha.650', + 'accentAlpha.700', + 'accentAlpha.750', + 'accentAlpha.800', + 'accentAlpha.850', + 'accentAlpha.900', + 'accentAlpha.950', ]); return { @@ -120,5 +196,43 @@ export const useChakraThemeTokens = () => { accent850, accent900, accent950, + baseAlpha50, + baseAlpha100, + baseAlpha150, + baseAlpha200, + baseAlpha250, + baseAlpha300, + baseAlpha350, + baseAlpha400, + baseAlpha450, + baseAlpha500, + baseAlpha550, + baseAlpha600, + baseAlpha650, + baseAlpha700, + baseAlpha750, + baseAlpha800, + baseAlpha850, + baseAlpha900, + baseAlpha950, + accentAlpha50, + accentAlpha100, + accentAlpha150, + accentAlpha200, + accentAlpha250, + accentAlpha300, + accentAlpha350, + accentAlpha400, + accentAlpha450, + accentAlpha500, + accentAlpha550, + accentAlpha600, + accentAlpha650, + accentAlpha700, + accentAlpha750, + accentAlpha800, + accentAlpha850, + accentAlpha900, + accentAlpha950, }; }; diff --git a/invokeai/frontend/web/src/common/util/serialize.ts b/invokeai/frontend/web/src/common/util/serialize.ts index a9352a8228..a5db921f8d 100644 --- a/invokeai/frontend/web/src/common/util/serialize.ts +++ b/invokeai/frontend/web/src/common/util/serialize.ts @@ -1,4 +1,10 @@ /** * Serialize an object to JSON and back to a new object */ -export const parseify = (obj: unknown) => JSON.parse(JSON.stringify(obj)); +export const parseify = (obj: unknown) => { + try { + return JSON.parse(JSON.stringify(obj)); + } catch { + return 'Error parsing object'; + } +}; diff --git a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx index cdab176cd2..4fffb82275 100644 --- a/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx +++ b/invokeai/frontend/web/src/features/controlNet/components/ControlNetImagePreview.tsx @@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/dist/query'; import { TypesafeDraggableData, TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; +} from 'features/dnd/types'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; diff --git a/invokeai/frontend/web/src/features/controlNet/store/types.ts b/invokeai/frontend/web/src/features/controlNet/store/types.ts index 2d028fd0bb..80edb41699 100644 --- a/invokeai/frontend/web/src/features/controlNet/store/types.ts +++ b/invokeai/frontend/web/src/features/controlNet/store/types.ts @@ -138,7 +138,7 @@ export type RequiredZoeDepthImageProcessorInvocation = O.Required< /** * Any ControlNet Processor node, with its parameters flagged as required */ -export type RequiredControlNetProcessorNode = +export type RequiredControlNetProcessorNode = O.Required< | RequiredCannyImageProcessorInvocation | RequiredContentShuffleImageProcessorInvocation | RequiredHedImageProcessorInvocation @@ -150,7 +150,9 @@ export type RequiredControlNetProcessorNode = | RequiredNormalbaeImageProcessorInvocation | RequiredOpenposeImageProcessorInvocation | RequiredPidiImageProcessorInvocation - | RequiredZoeDepthImageProcessorInvocation; + | RequiredZoeDepthImageProcessorInvocation, + 'id' +>; /** * Type guard for CannyImageProcessorInvocation diff --git a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts index 310521f32a..37be06bad6 100644 --- a/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts +++ b/invokeai/frontend/web/src/features/deleteImageModal/store/selectors.ts @@ -3,6 +3,7 @@ import { RootState } from 'app/store/store'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import { some } from 'lodash-es'; import { ImageUsage } from './types'; +import { isInvocationNode } from 'features/nodes/types/types'; export const getImageUsage = (state: RootState, image_name: string) => { const { generation, canvas, nodes, controlNet } = state; @@ -12,11 +13,11 @@ export const getImageUsage = (state: RootState, image_name: string) => { (obj) => obj.kind === 'image' && obj.imageName === image_name ); - const isNodesImage = nodes.nodes.some((node) => { + const isNodesImage = nodes.nodes.filter(isInvocationNode).some((node) => { return some( node.data.inputs, (input) => - input.type === 'image' && input.value?.image_name === image_name + input.type === 'ImageField' && input.value?.image_name === image_name ); }); diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx similarity index 70% rename from invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx rename to invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx index 56eeb9b5db..bffe738aa9 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/ImageDndContext.tsx +++ b/invokeai/frontend/web/src/features/dnd/components/AppDndContext.tsx @@ -6,23 +6,18 @@ import { useSensor, useSensors, } from '@dnd-kit/core'; -import { snapCenterToCursor } from '@dnd-kit/modifiers'; +import { logger } from 'app/logging/logger'; import { dndDropped } from 'app/store/middleware/listenerMiddleware/listeners/imageDropped'; import { useAppDispatch } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; import { AnimatePresence, motion } from 'framer-motion'; import { PropsWithChildren, memo, useCallback, useState } from 'react'; +import { useScaledModifer } from '../hooks/useScaledCenteredModifer'; +import { DragEndEvent, DragStartEvent, TypesafeDraggableData } from '../types'; +import { DndContextTypesafe } from './DndContextTypesafe'; import DragPreview from './DragPreview'; -import { - DndContext, - DragEndEvent, - DragStartEvent, - TypesafeDraggableData, -} from './typesafeDnd'; -import { logger } from 'app/logging/logger'; -type ImageDndContextProps = PropsWithChildren; - -const ImageDndContext = (props: ImageDndContextProps) => { +const AppDndContext = (props: PropsWithChildren) => { const [activeDragData, setActiveDragData] = useState(null); const log = logger('images'); @@ -31,7 +26,10 @@ const ImageDndContext = (props: ImageDndContextProps) => { const handleDragStart = useCallback( (event: DragStartEvent) => { - log.trace({ dragData: event.active.data.current }, 'Drag started'); + log.trace( + { dragData: parseify(event.active.data.current) }, + 'Drag started' + ); const activeData = event.active.data.current; if (!activeData) { return; @@ -43,7 +41,10 @@ const ImageDndContext = (props: ImageDndContextProps) => { const handleDragEnd = useCallback( (event: DragEndEvent) => { - log.trace({ dragData: event.active.data.current }, 'Drag ended'); + log.trace( + { dragData: parseify(event.active.data.current) }, + 'Drag ended' + ); const overData = event.over?.data.current; if (!activeDragData || !overData) { return; @@ -69,15 +70,29 @@ const ImageDndContext = (props: ImageDndContextProps) => { const sensors = useSensors(mouseSensor, touchSensor); + const scaledModifier = useScaledModifer(); + return ( - {props.children} - + {activeDragData && ( { )} - + ); }; -export default memo(ImageDndContext); +export default memo(AppDndContext); diff --git a/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx new file mode 100644 index 0000000000..06fede4dc8 --- /dev/null +++ b/invokeai/frontend/web/src/features/dnd/components/DndContextTypesafe.tsx @@ -0,0 +1,6 @@ +import { DndContext } from '@dnd-kit/core'; +import { DndContextTypesafeProps } from '../types'; + +export function DndContextTypesafe(props: DndContextTypesafeProps) { + return ; +} diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx similarity index 69% rename from invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx rename to invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx index c97778ffcd..0ee5d34b1a 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/DragPreview.tsx +++ b/invokeai/frontend/web/src/features/dnd/components/DragPreview.tsx @@ -1,6 +1,6 @@ -import { Box, ChakraProps, Flex, Heading, Image } from '@chakra-ui/react'; +import { Box, ChakraProps, Flex, Heading, Image, Text } from '@chakra-ui/react'; import { memo } from 'react'; -import { TypesafeDraggableData } from './typesafeDnd'; +import { TypesafeDraggableData } from '../types'; type OverlayDragImageProps = { dragData: TypesafeDraggableData | null; @@ -30,19 +30,38 @@ const DragPreview = (props: OverlayDragImageProps) => { return null; } + if (props.dragData.payloadType === 'NODE_FIELD') { + const { field, fieldTemplate } = props.dragData.payload; + return ( + + {field.label || fieldTemplate.title} + + ); + } + if (props.dragData.payloadType === 'IMAGE_DTO') { const { thumbnail_url, width, height } = props.dragData.payload.imageDTO; return ( { return ( (activeTabName === 'nodes' ? nodes.zoom : 1) +); + +/** + * Applies scaling to the drag transform (if on node editor tab) and centers it on cursor. + */ +export const useScaledModifer = () => { + const zoom = useAppSelector(selectZoom); + const modifier: Modifier = useCallback( + ({ activatorEvent, draggingNodeRect, transform }) => { + if (draggingNodeRect && activatorEvent) { + const activatorCoordinates = getEventCoordinates(activatorEvent); + + if (!activatorCoordinates) { + return transform; + } + + const offsetX = activatorCoordinates.x - draggingNodeRect.left; + const offsetY = activatorCoordinates.y - draggingNodeRect.top; + + const x = transform.x + offsetX - draggingNodeRect.width / 2; + const y = transform.y + offsetY - draggingNodeRect.height / 2; + const scaleX = transform.scaleX * zoom; + const scaleY = transform.scaleY * zoom; + + return { + x, + y, + scaleX, + scaleY, + }; + } + + return transform; + }, + [zoom] + ); + + return modifier; +}; diff --git a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx b/invokeai/frontend/web/src/features/dnd/types/index.ts similarity index 51% rename from invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx rename to invokeai/frontend/web/src/features/dnd/types/index.ts index 6f24302070..294132d0a3 100644 --- a/invokeai/frontend/web/src/app/components/ImageDnd/typesafeDnd.tsx +++ b/invokeai/frontend/web/src/features/dnd/types/index.ts @@ -3,7 +3,6 @@ import { Active, Collision, DndContextProps, - DndContext as OriginalDndContext, Over, Translate, UseDraggableArguments, @@ -11,6 +10,10 @@ import { useDraggable as useOriginalDraggable, useDroppable as useOriginalDroppable, } from '@dnd-kit/core'; +import { + InputFieldTemplate, + InputFieldValue, +} from 'features/nodes/types/types'; import { ImageDTO } from 'services/api/types'; type BaseDropData = { @@ -62,6 +65,10 @@ export type RemoveFromBoardDropData = BaseDropData & { actionType: 'REMOVE_FROM_BOARD'; }; +export type AddFieldToLinearViewDropData = BaseDropData & { + actionType: 'ADD_FIELD_TO_LINEAR'; +}; + export type TypesafeDroppableData = | CurrentImageDropData | InitialImageDropData @@ -71,12 +78,22 @@ export type TypesafeDroppableData = | AddToBatchDropData | NodesMultiImageDropData | AddToBoardDropData - | RemoveFromBoardDropData; + | RemoveFromBoardDropData + | AddFieldToLinearViewDropData; type BaseDragData = { id: string; }; +export type NodeFieldDraggableData = BaseDragData & { + payloadType: 'NODE_FIELD'; + payload: { + nodeId: string; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + }; +}; + export type ImageDraggableData = BaseDragData & { payloadType: 'IMAGE_DTO'; payload: { imageDTO: ImageDTO }; @@ -87,14 +104,17 @@ export type ImageDTOsDraggableData = BaseDragData & { payload: { imageDTOs: ImageDTO[] }; }; -export type TypesafeDraggableData = ImageDraggableData | ImageDTOsDraggableData; +export type TypesafeDraggableData = + | NodeFieldDraggableData + | ImageDraggableData + | ImageDTOsDraggableData; -interface UseDroppableTypesafeArguments +export interface UseDroppableTypesafeArguments extends Omit { data?: TypesafeDroppableData; } -type UseDroppableTypesafeReturnValue = Omit< +export type UseDroppableTypesafeReturnValue = Omit< ReturnType, 'active' | 'over' > & { @@ -102,16 +122,12 @@ type UseDroppableTypesafeReturnValue = Omit< over: TypesafeOver | null; }; -export function useDroppable(props: UseDroppableTypesafeArguments) { - return useOriginalDroppable(props) as UseDroppableTypesafeReturnValue; -} - -interface UseDraggableTypesafeArguments +export interface UseDraggableTypesafeArguments extends Omit { data?: TypesafeDraggableData; } -type UseDraggableTypesafeReturnValue = Omit< +export type UseDraggableTypesafeReturnValue = Omit< ReturnType, 'active' | 'over' > & { @@ -119,102 +135,14 @@ type UseDraggableTypesafeReturnValue = Omit< over: TypesafeOver | null; }; -export function useDraggable(props: UseDraggableTypesafeArguments) { - return useOriginalDraggable(props) as UseDraggableTypesafeReturnValue; -} - -interface TypesafeActive extends Omit { +export interface TypesafeActive extends Omit { data: React.MutableRefObject; } -interface TypesafeOver extends Omit { +export interface TypesafeOver extends Omit { data: React.MutableRefObject; } -export const isValidDrop = ( - overData: TypesafeDroppableData | undefined, - active: TypesafeActive | null -) => { - if (!overData || !active?.data.current) { - return false; - } - - const { actionType } = overData; - const { payloadType } = active.data.current; - - if (overData.id === active.data.current.id) { - return false; - } - - switch (actionType) { - case 'SET_CURRENT_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_INITIAL_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_CONTROLNET_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_CANVAS_INITIAL_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_NODES_IMAGE': - return payloadType === 'IMAGE_DTO'; - case 'SET_MULTI_NODES_IMAGE': - return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - case 'ADD_TO_BATCH': - return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - case 'ADD_TO_BOARD': { - // If the board is the same, don't allow the drop - - // Check the payload types - const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - if (!isPayloadValid) { - return false; - } - - // Check if the image's board is the board we are dragging onto - if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; - const currentBoard = imageDTO.board_id ?? 'none'; - const destinationBoard = overData.context.boardId; - - return currentBoard !== destinationBoard; - } - - if (payloadType === 'IMAGE_DTOS') { - // TODO (multi-select) - return true; - } - - return false; - } - case 'REMOVE_FROM_BOARD': { - // If the board is the same, don't allow the drop - - // Check the payload types - const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; - if (!isPayloadValid) { - return false; - } - - // Check if the image's board is the board we are dragging onto - if (payloadType === 'IMAGE_DTO') { - const { imageDTO } = active.data.current.payload; - const currentBoard = imageDTO.board_id; - - return currentBoard !== 'none'; - } - - if (payloadType === 'IMAGE_DTOS') { - // TODO (multi-select) - return true; - } - - return false; - } - default: - return false; - } -}; - interface DragEvent { activatorEvent: Event; active: TypesafeActive; @@ -240,6 +168,3 @@ export interface DndContextTypesafeProps onDragEnd?(event: DragEndEvent): void; onDragCancel?(event: DragCancelEvent): void; } -export function DndContext(props: DndContextTypesafeProps) { - return ; -} diff --git a/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts new file mode 100644 index 0000000000..f704d22dff --- /dev/null +++ b/invokeai/frontend/web/src/features/dnd/util/isValidDrop.ts @@ -0,0 +1,87 @@ +import { TypesafeActive, TypesafeDroppableData } from '../types'; + +export const isValidDrop = ( + overData: TypesafeDroppableData | undefined, + active: TypesafeActive | null +) => { + if (!overData || !active?.data.current) { + return false; + } + + const { actionType } = overData; + const { payloadType } = active.data.current; + + if (overData.id === active.data.current.id) { + return false; + } + + switch (actionType) { + case 'ADD_FIELD_TO_LINEAR': + return payloadType === 'NODE_FIELD'; + case 'SET_CURRENT_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_INITIAL_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_CONTROLNET_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_CANVAS_INITIAL_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_NODES_IMAGE': + return payloadType === 'IMAGE_DTO'; + case 'SET_MULTI_NODES_IMAGE': + return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + case 'ADD_TO_BATCH': + return payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + case 'ADD_TO_BOARD': { + // If the board is the same, don't allow the drop + + // Check the payload types + const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + if (!isPayloadValid) { + return false; + } + + // Check if the image's board is the board we are dragging onto + if (payloadType === 'IMAGE_DTO') { + const { imageDTO } = active.data.current.payload; + const currentBoard = imageDTO.board_id ?? 'none'; + const destinationBoard = overData.context.boardId; + + return currentBoard !== destinationBoard; + } + + if (payloadType === 'IMAGE_DTOS') { + // TODO (multi-select) + return true; + } + + return false; + } + case 'REMOVE_FROM_BOARD': { + // If the board is the same, don't allow the drop + + // Check the payload types + const isPayloadValid = payloadType === 'IMAGE_DTO' || 'IMAGE_DTOS'; + if (!isPayloadValid) { + return false; + } + + // Check if the image's board is the board we are dragging onto + if (payloadType === 'IMAGE_DTO') { + const { imageDTO } = active.data.current.payload; + const currentBoard = imageDTO.board_id; + + return currentBoard !== 'none'; + } + + if (payloadType === 'IMAGE_DTOS') { + // TODO (multi-select) + return true; + } + + return false; + } + default: + return false; + } +}; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx index 228ce7080c..696a8b748b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GalleryBoard.tsx @@ -11,7 +11,6 @@ import { } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { AddToBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; @@ -32,6 +31,7 @@ import { useGetImageDTOQuery } from 'services/api/endpoints/images'; import { BoardDTO } from 'services/api/types'; import AutoAddIcon from '../AutoAddIcon'; import BoardContextMenu from '../BoardContextMenu'; +import { AddToBoardDropData } from 'features/dnd/types'; interface GalleryBoardProps { board: BoardDTO; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx index 0d630c524d..1698a81ac0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/GenericBoard.tsx @@ -1,7 +1,7 @@ import { As, Badge, Flex } from '@chakra-ui/react'; -import { TypesafeDroppableData } from 'app/components/ImageDnd/typesafeDnd'; import IAIDroppable from 'common/components/IAIDroppable'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { TypesafeDroppableData } from 'features/dnd/types'; import { BoardId } from 'features/gallery/store/types'; import { ReactNode } from 'react'; import BoardContextMenu from '../BoardContextMenu'; diff --git a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx index f1341b1146..fec280db0f 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Boards/BoardsList/NoBoardBoard.tsx @@ -1,15 +1,15 @@ import { Box, Flex, Image, Text } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; -import { RemoveFromBoardDropData } from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions'; import InvokeAILogoImage from 'assets/images/logo.png'; import IAIDroppable from 'common/components/IAIDroppable'; import SelectionOverlay from 'common/components/SelectionOverlay'; +import { RemoveFromBoardDropData } from 'features/dnd/types'; import { - boardIdSelected, autoAddBoardIdChanged, + boardIdSelected, } from 'features/gallery/store/gallerySlice'; import { memo, useCallback, useMemo, useState } from 'react'; import { useBoardName } from 'services/api/hooks/useBoardName'; diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx index f78ee286ef..2576c8e9e3 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx @@ -1,14 +1,14 @@ import { Box, Flex, Image } from '@chakra-ui/react'; import { createSelector } from '@reduxjs/toolkit'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { - TypesafeDraggableData, - TypesafeDroppableData, -} from 'app/components/ImageDnd/typesafeDnd'; import { stateSelector } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { + TypesafeDraggableData, + TypesafeDroppableData, +} from 'features/dnd/types'; import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import { AnimatePresence, motion } from 'framer-motion'; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx index f2ff2ad30b..804df49b8e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGalleryContent.tsx @@ -52,11 +52,13 @@ const ImageGalleryContent = () => { return ( diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index c9eee5f1f5..97f8199aed 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -1,9 +1,4 @@ import { Box, Flex } from '@chakra-ui/react'; -import { - ImageDTOsDraggableData, - ImageDraggableData, - TypesafeDraggableData, -} from 'app/components/ImageDnd/typesafeDnd'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIDndImage from 'common/components/IAIDndImage'; import IAIFillSkeleton from 'common/components/IAIFillSkeleton'; @@ -12,6 +7,11 @@ import { imagesToDeleteSelected } from 'features/deleteImageModal/store/slice'; import { MouseEvent, memo, useCallback, useMemo } from 'react'; import { FaTrash } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; +import { + ImageDTOsDraggableData, + ImageDraggableData, + TypesafeDraggableData, +} from 'features/dnd/types'; interface HoverableImageProps { imageName: string; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx index 4a56fe0e9a..bacd5c38ad 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImageGrid.tsx @@ -26,7 +26,7 @@ const overlayScrollbarsConfig: UseOverlayScrollbarsParams = { options: { scrollbars: { visibility: 'auto', - autoHide: 'leave', + autoHide: 'scroll', autoHideDelay: 1300, theme: 'os-theme-dark', }, diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx index 590d40438b..69385607de 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataJSON.tsx @@ -1,26 +1,40 @@ import { Box, Flex, IconButton, Tooltip } from '@chakra-ui/react'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; -import { useMemo } from 'react'; -import { FaCopy } from 'react-icons/fa'; +import { useCallback, useMemo } from 'react'; +import { FaCopy, FaSave } from 'react-icons/fa'; type Props = { - copyTooltip: string; + label: string; jsonObject: object; + fileName?: string; }; const ImageMetadataJSON = (props: Props) => { - const { copyTooltip, jsonObject } = props; + const { label, jsonObject, fileName } = props; const jsonString = useMemo( () => JSON.stringify(jsonObject, null, 2), [jsonObject] ); + const handleCopy = useCallback(() => { + navigator.clipboard.writeText(jsonString); + }, [jsonString]); + + const handleSave = useCallback(() => { + const blob = new Blob([jsonString]); + const a = document.createElement('a'); + a.href = URL.createObjectURL(blob); + a.download = `${fileName || label}.json`; + document.body.appendChild(a); + a.click(); + a.remove(); + }, [jsonString, label, fileName]); + return ( { bottom: 0, overflow: 'auto', p: 4, + fontSize: 'sm', }} > { options={{ scrollbars: { visibility: 'auto', - autoHide: 'move', + autoHide: 'scroll', autoHideDelay: 1300, theme: 'os-theme-dark', }, @@ -54,12 +69,22 @@ const ImageMetadataJSON = (props: Props) => { - + } + variant="ghost" + opacity={0.7} + onClick={handleSave} + /> + + + } variant="ghost" - onClick={() => navigator.clipboard.writeText(jsonString)} + opacity={0.7} + onClick={handleCopy} /> diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx index e1f2a9e46a..d70aea8a8d 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageMetadataViewer/ImageMetadataViewer.tsx @@ -10,7 +10,8 @@ import { Text, } from '@chakra-ui/react'; import { skipToken } from '@reduxjs/toolkit/dist/query'; -import { memo, useMemo } from 'react'; +import { IAINoContentFallback } from 'common/components/IAIImageFallback'; +import { memo } from 'react'; import { useGetImageMetadataQuery } from 'services/api/endpoints/images'; import { ImageDTO } from 'services/api/types'; import { useDebounce } from 'use-debounce'; @@ -41,48 +42,15 @@ const ImageMetadataViewer = ({ image }: ImageMetadataViewerProps) => { const metadata = currentData?.metadata; const graph = currentData?.graph; - const tabData = useMemo(() => { - const _tabData: { label: string; data: object; copyTooltip: string }[] = []; - - if (metadata) { - _tabData.push({ - label: 'Core Metadata', - data: metadata, - copyTooltip: 'Copy Core Metadata JSON', - }); - } - - if (image) { - _tabData.push({ - label: 'Image Details', - data: image, - copyTooltip: 'Copy Image Details JSON', - }); - } - - if (graph) { - _tabData.push({ - label: 'Graph', - data: graph, - copyTooltip: 'Copy Graph JSON', - }); - } - return _tabData; - }, [metadata, graph, image]); - return ( { sx={{ display: 'flex', flexDir: 'column', w: 'full', h: 'full' }} > - {tabData.map((tab) => ( - - - {tab.label} - - - ))} + Core Metadata + Image Details + Graph - - {tabData.map((tab) => ( - - - - ))} + + + {metadata ? ( + + ) : ( + + )} + + + {image ? ( + + ) : ( + + )} + + + {graph ? ( + + ) : ( + + )} + diff --git a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx index a1a1acf1f8..a816762d0f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/AddNodeMenu.tsx @@ -9,30 +9,40 @@ import { map } from 'lodash-es'; import { forwardRef, useCallback } from 'react'; import 'reactflow/dist/style.css'; import { AnyInvocationType } from 'services/events/types'; -import { useBuildInvocation } from '../hooks/useBuildInvocation'; +import { useBuildNodeData } from '../hooks/useBuildNodeData'; import { nodeAdded } from '../store/nodesSlice'; type NodeTemplate = { label: string; value: string; description: string; + tags: string[]; }; const selector = createSelector( [stateSelector], ({ nodes }) => { - const data: NodeTemplate[] = map(nodes.invocationTemplates, (template) => { + const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { return { label: template.title, value: template.type, description: template.description, + tags: template.tags, }; }); data.push({ label: 'Progress Image', - value: 'progress_image', - description: 'Displays the progress image in the Node Editor', + value: 'current_image', + description: 'Displays the current image in the Node Editor', + tags: ['progress'], + }); + + data.push({ + label: 'Notes', + value: 'notes', + description: 'Add notes about your workflow', + tags: ['notes'], }); return { data }; @@ -44,7 +54,7 @@ const AddNodeMenu = () => { const dispatch = useAppDispatch(); const { data } = useAppSelector(selector); - const buildInvocation = useBuildInvocation(); + const buildInvocation = useBuildNodeData(); const toaster = useAppToaster(); @@ -89,11 +99,12 @@ const AddNodeMenu = () => { filter={(value, item: NodeTemplate) => item.label.toLowerCase().includes(value.toLowerCase().trim()) || item.value.toLowerCase().includes(value.toLowerCase().trim()) || - item.description.toLowerCase().includes(value.toLowerCase().trim()) + item.description.toLowerCase().includes(value.toLowerCase().trim()) || + item.tags.includes(value.toLowerCase().trim()) } onChange={handleChange} sx={{ - width: '18rem', + width: '24rem', }} /> diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx new file mode 100644 index 0000000000..678d8e3d1d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomConnectionLine.tsx @@ -0,0 +1,61 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { ConnectionLineComponentProps, getBezierPath } from 'reactflow'; +import { FIELDS, colorTokenToCssVar } from '../types/constants'; + +const selector = createSelector(stateSelector, ({ nodes }) => { + const { shouldAnimateEdges, currentConnectionFieldType, shouldColorEdges } = + nodes; + + const stroke = + currentConnectionFieldType && shouldColorEdges + ? colorTokenToCssVar(FIELDS[currentConnectionFieldType].color) + : colorTokenToCssVar('base.500'); + + let className = 'react-flow__custom_connection-path'; + + if (shouldAnimateEdges) { + className = className.concat(' animated'); + } + + return { + stroke, + className, + }; +}); + +export const CustomConnectionLine = ({ + fromX, + fromY, + fromPosition, + toX, + toY, + toPosition, +}: ConnectionLineComponentProps) => { + const { stroke, className } = useAppSelector(selector); + + const pathParams = { + sourceX: fromX, + sourceY: fromY, + sourcePosition: fromPosition, + targetX: toX, + targetY: toY, + targetPosition: toPosition, + }; + + const [dAttr] = getBezierPath(pathParams); + + return ( + + + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx new file mode 100644 index 0000000000..e0ccc6e323 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomEdges.tsx @@ -0,0 +1,183 @@ +import { Badge, Flex } from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { useMemo } from 'react'; +import { + BaseEdge, + EdgeLabelRenderer, + EdgeProps, + getBezierPath, +} from 'reactflow'; +import { FIELDS, colorTokenToCssVar } from '../types/constants'; +import { isInvocationNode } from '../types/types'; + +const makeEdgeSelector = ( + source: string, + sourceHandleId: string | null | undefined, + target: string, + targetHandleId: string | null | undefined, + selected?: boolean +) => + createSelector(stateSelector, ({ nodes }) => { + const sourceNode = nodes.nodes.find((node) => node.id === source); + const targetNode = nodes.nodes.find((node) => node.id === target); + + const isInvocationToInvocationEdge = + isInvocationNode(sourceNode) && isInvocationNode(targetNode); + + const isSelected = sourceNode?.selected || targetNode?.selected || selected; + const sourceType = isInvocationToInvocationEdge + ? sourceNode?.data?.outputs[sourceHandleId || '']?.type + : undefined; + + const stroke = + sourceType && nodes.shouldColorEdges + ? colorTokenToCssVar(FIELDS[sourceType].color) + : colorTokenToCssVar('base.500'); + + return { + isSelected, + shouldAnimate: nodes.shouldAnimateEdges && isSelected, + stroke, + }; + }); + +const CollapsedEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + data, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps<{ count: number }>) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [selected, source, sourceHandleId, target, targetHandleId] + ); + + const { isSelected, shouldAnimate } = useAppSelector(selector); + + const [edgePath, labelX, labelY] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + const { base500 } = useChakraThemeTokens(); + + return ( + <> + + {data?.count && data.count > 1 && ( + + + + {data.count} + + + + )} + + ); +}; + +const DefaultEdge = ({ + sourceX, + sourceY, + targetX, + targetY, + sourcePosition, + targetPosition, + markerEnd, + selected, + source, + target, + sourceHandleId, + targetHandleId, +}: EdgeProps) => { + const selector = useMemo( + () => + makeEdgeSelector( + source, + sourceHandleId, + target, + targetHandleId, + selected + ), + [source, sourceHandleId, target, targetHandleId, selected] + ); + + const { isSelected, shouldAnimate, stroke } = useAppSelector(selector); + + const [edgePath] = getBezierPath({ + sourceX, + sourceY, + sourcePosition, + targetX, + targetY, + targetPosition, + }); + + return ( + + ); +}; + +export const edgeTypes = { + collapsed: CollapsedEdge, + default: DefaultEdge, +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx new file mode 100644 index 0000000000..3aacb3cd58 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/CustomNodes.tsx @@ -0,0 +1,9 @@ +import InvocationNode from './nodes/InvocationNode'; +import CurrentImageNode from './nodes/CurrentImageNode'; +import NotesNode from './nodes/NotesNode'; + +export const nodeTypes = { + invocation: InvocationNode, + current_image: CurrentImageNode, + notes: NotesNode, +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx deleted file mode 100644 index 86099a7315..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/FieldHandle.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import { Tooltip } from '@chakra-ui/react'; -import { CSSProperties, memo } from 'react'; -import { Handle, Position, Connection, HandleType } from 'reactflow'; -import { FIELDS, HANDLE_TOOLTIP_OPEN_DELAY } from '../types/constants'; -// import { useConnectionEventStyles } from '../hooks/useConnectionEventStyles'; -import { InputFieldTemplate, OutputFieldTemplate } from '../types/types'; - -const handleBaseStyles: CSSProperties = { - position: 'absolute', - width: '1rem', - height: '1rem', - borderWidth: 0, -}; - -const inputHandleStyles: CSSProperties = { - left: '-1rem', -}; - -const outputHandleStyles: CSSProperties = { - right: '-0.5rem', -}; - -// const requiredConnectionStyles: CSSProperties = { -// boxShadow: '0 0 0.5rem 0.5rem var(--invokeai-colors-error-400)', -// }; - -type FieldHandleProps = { - nodeId: string; - field: InputFieldTemplate | OutputFieldTemplate; - isValidConnection: (connection: Connection) => boolean; - handleType: HandleType; - styles?: CSSProperties; -}; - -const FieldHandle = (props: FieldHandleProps) => { - const { field, isValidConnection, handleType, styles } = props; - const { name, type } = field; - - return ( - - - - ); -}; - -export default memo(FieldHandle); diff --git a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx index 78316cc694..a523cc29fe 100644 --- a/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/FieldTypeLegend.tsx @@ -1,8 +1,8 @@ -import 'reactflow/dist/style.css'; -import { Tooltip, Badge, Flex } from '@chakra-ui/react'; +import { Badge, Flex, Tooltip } from '@chakra-ui/react'; import { map } from 'lodash-es'; -import { FIELDS } from '../types/constants'; import { memo } from 'react'; +import 'reactflow/dist/style.css'; +import { FIELDS } from '../types/constants'; const FieldTypeLegend = () => { return ( @@ -10,8 +10,14 @@ const FieldTypeLegend = () => { {map(FIELDS, ({ title, description, color }, key) => ( {title} diff --git a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx index 7b0718182b..71062e9774 100644 --- a/invokeai/frontend/web/src/features/nodes/components/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/Flow.tsx @@ -1,4 +1,3 @@ -import { RootState } from 'app/store/store'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; import { @@ -7,35 +6,49 @@ import { OnConnectEnd, OnConnectStart, OnEdgesChange, + OnEdgesDelete, OnInit, + OnMove, OnNodesChange, + OnNodesDelete, + OnSelectionChangeFunc, + ProOptions, ReactFlow, } from 'reactflow'; +import { useIsValidConnection } from '../hooks/useIsValidConnection'; import { connectionEnded, connectionMade, connectionStarted, edgesChanged, + edgesDeleted, nodesChanged, - setEditorInstance, + nodesDeleted, + selectedEdgesChanged, + selectedNodesChanged, + zoomChanged, } from '../store/nodesSlice'; -import { InvocationComponent } from './InvocationComponent'; -import ProgressImageNode from './ProgressImageNode'; -import BottomLeftPanel from './panels/BottomLeftPanel.tsx'; -import MinimapPanel from './panels/MinimapPanel'; -import TopCenterPanel from './panels/TopCenterPanel'; -import TopLeftPanel from './panels/TopLeftPanel'; -import TopRightPanel from './panels/TopRightPanel'; +import { CustomConnectionLine } from './CustomConnectionLine'; +import { edgeTypes } from './CustomEdges'; +import { nodeTypes } from './CustomNodes'; +import BottomLeftPanel from './editorPanels/BottomLeftPanel'; +import MinimapPanel from './editorPanels/MinimapPanel'; +import TopCenterPanel from './editorPanels/TopCenterPanel'; +import TopLeftPanel from './editorPanels/TopLeftPanel'; +import TopRightPanel from './editorPanels/TopRightPanel'; -const nodeTypes = { - invocation: InvocationComponent, - progress_image: ProgressImageNode, -}; +// TODO: can we support reactflow? if not, we could style the attribution so it matches the app +const proOptions: ProOptions = { hideAttribution: true }; export const Flow = () => { const dispatch = useAppDispatch(); - const nodes = useAppSelector((state: RootState) => state.nodes.nodes); - const edges = useAppSelector((state: RootState) => state.nodes.edges); + const nodes = useAppSelector((state) => state.nodes.nodes); + const edges = useAppSelector((state) => state.nodes.edges); + const shouldSnapToGrid = useAppSelector( + (state) => state.nodes.shouldSnapToGrid + ); + + const isValidConnection = useIsValidConnection(); const onNodesChange: OnNodesChange = useCallback( (changes) => { @@ -69,10 +82,36 @@ export const Flow = () => { dispatch(connectionEnded()); }, [dispatch]); - const onInit: OnInit = useCallback( - (v) => { - dispatch(setEditorInstance(v)); - if (v) v.fitView(); + const onInit: OnInit = useCallback((v) => { + v.fitView(); + }, []); + + const onEdgesDelete: OnEdgesDelete = useCallback( + (edges) => { + dispatch(edgesDeleted(edges)); + }, + [dispatch] + ); + + const onNodesDelete: OnNodesDelete = useCallback( + (nodes) => { + dispatch(nodesDeleted(nodes)); + }, + [dispatch] + ); + + const handleSelectionChange: OnSelectionChangeFunc = useCallback( + ({ nodes, edges }) => { + dispatch(selectedNodesChanged(nodes ? nodes.map((n) => n.id) : [])); + dispatch(selectedEdgesChanged(edges ? edges.map((e) => e.id) : [])); + }, + [dispatch] + ); + + const handleMove: OnMove = useCallback( + (e, viewport) => { + const { zoom } = viewport; + dispatch(zoomChanged(zoom)); }, [dispatch] ); @@ -80,24 +119,33 @@ export const Flow = () => { return ( - + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx deleted file mode 100644 index 7b56bc95b4..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeHeader.tsx +++ /dev/null @@ -1,55 +0,0 @@ -import { Flex, Heading, Icon, Tooltip } from '@chakra-ui/react'; -import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/hooks/useBuildInvocation'; -import { memo } from 'react'; -import { FaInfoCircle } from 'react-icons/fa'; - -interface IAINodeHeaderProps { - nodeId?: string; - title?: string; - description?: string; -} - -const IAINodeHeader = (props: IAINodeHeaderProps) => { - const { nodeId, title, description } = props; - return ( - - - - {title} - - - - - - - ); -}; - -export default memo(IAINodeHeader); diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx deleted file mode 100644 index 6f779e4295..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeInputs.tsx +++ /dev/null @@ -1,149 +0,0 @@ -import { - Box, - Divider, - Flex, - FormControl, - FormLabel, - HStack, - Tooltip, -} from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { useAppSelector } from 'app/store/storeHooks'; -import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; -import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; -import { - InputFieldTemplate, - InputFieldValue, - InvocationTemplate, -} from 'features/nodes/types/types'; -import { map } from 'lodash-es'; -import { ReactNode, memo, useCallback } from 'react'; -import FieldHandle from '../FieldHandle'; -import InputFieldComponent from '../InputFieldComponent'; - -interface IAINodeInputProps { - nodeId: string; - - input: InputFieldValue; - template?: InputFieldTemplate | undefined; - connected: boolean; -} - -function IAINodeInput(props: IAINodeInputProps) { - const { nodeId, input, template, connected } = props; - const isValidConnection = useIsValidConnection(); - - return ( - - - {!template ? ( - - Unknown input: {input.name} - - ) : ( - <> - - - - {template?.title} - - - - - - {!['never', 'directOnly'].includes( - template?.inputRequirement ?? '' - ) && ( - - )} - - )} - - - ); -} - -interface IAINodeInputsProps { - nodeId: string; - template: InvocationTemplate; - inputs: Record; -} - -const IAINodeInputs = (props: IAINodeInputsProps) => { - const { nodeId, template, inputs } = props; - - const edges = useAppSelector((state: RootState) => state.nodes.edges); - - const renderIAINodeInputs = useCallback(() => { - const IAINodeInputsToRender: ReactNode[] = []; - const inputSockets = map(inputs); - - inputSockets.forEach((inputSocket, index) => { - const inputTemplate = template.inputs[inputSocket.name]; - - const isConnected = Boolean( - edges.filter((connectedInput) => { - return ( - connectedInput.target === nodeId && - connectedInput.targetHandle === inputSocket.name - ); - }).length - ); - - if (index < inputSockets.length) { - IAINodeInputsToRender.push( - - ); - } - - IAINodeInputsToRender.push( - - ); - }); - - return ( - - {IAINodeInputsToRender} - - ); - }, [edges, inputs, nodeId, template.inputs]); - - return renderIAINodeInputs(); -}; - -export default memo(IAINodeInputs); diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx b/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx deleted file mode 100644 index 2cb0bcde8d..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeOutputs.tsx +++ /dev/null @@ -1,97 +0,0 @@ -import { - InvocationTemplate, - OutputFieldTemplate, - OutputFieldValue, -} from 'features/nodes/types/types'; -import { memo, ReactNode, useCallback } from 'react'; -import { map } from 'lodash-es'; -import { useAppSelector } from 'app/store/storeHooks'; -import { RootState } from 'app/store/store'; -import { Box, Flex, FormControl, FormLabel, HStack } from '@chakra-ui/react'; -import FieldHandle from '../FieldHandle'; -import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'; - -interface IAINodeOutputProps { - nodeId: string; - output: OutputFieldValue; - template?: OutputFieldTemplate | undefined; - connected: boolean; -} - -function IAINodeOutput(props: IAINodeOutputProps) { - const { nodeId, output, template, connected } = props; - const isValidConnection = useIsValidConnection(); - - return ( - - - {!template ? ( - - - Unknown Output: {output.name} - - - ) : ( - <> - - {template?.title} - - - - )} - - - ); -} - -interface IAINodeOutputsProps { - nodeId: string; - template: InvocationTemplate; - outputs: Record; -} - -const IAINodeOutputs = (props: IAINodeOutputsProps) => { - const { nodeId, template, outputs } = props; - - const edges = useAppSelector((state: RootState) => state.nodes.edges); - - const renderIAINodeOutputs = useCallback(() => { - const IAINodeOutputsToRender: ReactNode[] = []; - const outputSockets = map(outputs); - - outputSockets.forEach((outputSocket) => { - const outputTemplate = template.outputs[outputSocket.name]; - - const isConnected = Boolean( - edges.filter((connectedInput) => { - return ( - connectedInput.source === nodeId && - connectedInput.sourceHandle === outputSocket.name - ); - }).length - ); - - IAINodeOutputsToRender.push( - - ); - }); - - return {IAINodeOutputsToRender}; - }, [edges, nodeId, outputs, template.outputs]); - - return renderIAINodeOutputs(); -}; - -export default memo(IAINodeOutputs); diff --git a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx deleted file mode 100644 index 0ecc43ef9c..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/InputFieldComponent.tsx +++ /dev/null @@ -1,252 +0,0 @@ -import { Box } from '@chakra-ui/react'; -import { memo } from 'react'; -import { InputFieldTemplate, InputFieldValue } from '../types/types'; -import ArrayInputFieldComponent from './fields/ArrayInputFieldComponent'; -import BooleanInputFieldComponent from './fields/BooleanInputFieldComponent'; -import ClipInputFieldComponent from './fields/ClipInputFieldComponent'; -import ColorInputFieldComponent from './fields/ColorInputFieldComponent'; -import ConditioningInputFieldComponent from './fields/ConditioningInputFieldComponent'; -import ControlInputFieldComponent from './fields/ControlInputFieldComponent'; -import ControlNetModelInputFieldComponent from './fields/ControlNetModelInputFieldComponent'; -import EnumInputFieldComponent from './fields/EnumInputFieldComponent'; -import ImageCollectionInputFieldComponent from './fields/ImageCollectionInputFieldComponent'; -import ImageInputFieldComponent from './fields/ImageInputFieldComponent'; -import ItemInputFieldComponent from './fields/ItemInputFieldComponent'; -import LatentsInputFieldComponent from './fields/LatentsInputFieldComponent'; -import LoRAModelInputFieldComponent from './fields/LoRAModelInputFieldComponent'; -import ModelInputFieldComponent from './fields/ModelInputFieldComponent'; -import NumberInputFieldComponent from './fields/NumberInputFieldComponent'; -import StringInputFieldComponent from './fields/StringInputFieldComponent'; -import UnetInputFieldComponent from './fields/UnetInputFieldComponent'; -import VaeInputFieldComponent from './fields/VaeInputFieldComponent'; -import VaeModelInputFieldComponent from './fields/VaeModelInputFieldComponent'; -import RefinerModelInputFieldComponent from './fields/RefinerModelInputFieldComponent'; - -type InputFieldComponentProps = { - nodeId: string; - field: InputFieldValue; - template: InputFieldTemplate; -}; - -// build an individual input element based on the schema -const InputFieldComponent = (props: InputFieldComponentProps) => { - const { nodeId, field, template } = props; - const { type } = field; - - if (type === 'string' && template.type === 'string') { - return ( - - ); - } - - if (type === 'boolean' && template.type === 'boolean') { - return ( - - ); - } - - if ( - (type === 'integer' && template.type === 'integer') || - (type === 'float' && template.type === 'float') - ) { - return ( - - ); - } - - if (type === 'enum' && template.type === 'enum') { - return ( - - ); - } - - if (type === 'image' && template.type === 'image') { - return ( - - ); - } - - if (type === 'latents' && template.type === 'latents') { - return ( - - ); - } - - if (type === 'conditioning' && template.type === 'conditioning') { - return ( - - ); - } - - if (type === 'unet' && template.type === 'unet') { - return ( - - ); - } - - if (type === 'clip' && template.type === 'clip') { - return ( - - ); - } - - if (type === 'vae' && template.type === 'vae') { - return ( - - ); - } - - if (type === 'control' && template.type === 'control') { - return ( - - ); - } - - if (type === 'model' && template.type === 'model') { - return ( - - ); - } - - if (type === 'refiner_model' && template.type === 'refiner_model') { - return ( - - ); - } - - if (type === 'vae_model' && template.type === 'vae_model') { - return ( - - ); - } - - if (type === 'lora_model' && template.type === 'lora_model') { - return ( - - ); - } - - if (type === 'controlnet_model' && template.type === 'controlnet_model') { - return ( - - ); - } - - if (type === 'array' && template.type === 'array') { - return ( - - ); - } - - if (type === 'item' && template.type === 'item') { - return ( - - ); - } - - if (type === 'color' && template.type === 'color') { - return ( - - ); - } - - if (type === 'item' && template.type === 'item') { - return ( - - ); - } - - if (type === 'image_collection' && template.type === 'image_collection') { - return ( - - ); - } - - return Unknown field type: {type}; -}; - -export default memo(InputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx new file mode 100644 index 0000000000..d67ca10dcc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapseButton.tsx @@ -0,0 +1,57 @@ +import { ChevronUpIcon } from '@chakra-ui/icons'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import { nodeIsOpenChanged } from 'features/nodes/store/nodesSlice'; +import { NodeData } from 'features/nodes/types/types'; +import { memo, useCallback } from 'react'; +import { NodeProps, useUpdateNodeInternals } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; +} + +const NodeCollapseButton = (props: Props) => { + const { id: nodeId, isOpen } = props.nodeProps.data; + const dispatch = useAppDispatch(); + const updateNodeInternals = useUpdateNodeInternals(); + + const handleClick = useCallback(() => { + dispatch(nodeIsOpenChanged({ nodeId, isOpen: !isOpen })); + updateNodeInternals(nodeId); + }, [dispatch, isOpen, nodeId, updateNodeInternals]); + + return ( + + } + /> + ); +}; + +export default memo(NodeCollapseButton); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx new file mode 100644 index 0000000000..ece24f6f8c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeCollapsedHandles.tsx @@ -0,0 +1,74 @@ +import { useColorModeValue } from '@chakra-ui/react'; +import { useChakraThemeTokens } from 'common/hooks/useChakraThemeTokens'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { map } from 'lodash-es'; +import { CSSProperties, memo, useMemo } from 'react'; +import { Handle, NodeProps, Position } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +} + +const NodeCollapsedHandles = (props: Props) => { + const { data } = props.nodeProps; + const { base400, base600 } = useChakraThemeTokens(); + const backgroundColor = useColorModeValue(base400, base600); + + const dummyHandleStyles: CSSProperties = useMemo( + () => ({ + borderWidth: 0, + borderRadius: '3px', + width: '1rem', + height: '1rem', + backgroundColor, + zIndex: -1, + }), + [backgroundColor] + ); + + return ( + <> + + {map(data.inputs, (input) => ( + false} + position={Position.Left} + style={{ visibility: 'hidden' }} + /> + ))} + false} + isConnectable={false} + position={Position.Right} + style={{ ...dummyHandleStyles, right: '-0.5rem' }} + /> + {map(data.outputs, (output) => ( + false} + position={Position.Right} + style={{ visibility: 'hidden' }} + /> + ))} + + ); +}; + +export default memo(NodeCollapsedHandles); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx new file mode 100644 index 0000000000..3c513ed29a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeFooter.tsx @@ -0,0 +1,77 @@ +import { + Checkbox, + Flex, + FormControl, + FormLabel, + Spacer, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { some } from 'lodash-es'; +import { ChangeEvent, memo, useCallback, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; + +type Props = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +}; + +const NodeFooter = (props: Props) => { + const { nodeProps, nodeTemplate } = props; + const dispatch = useAppDispatch(); + + const hasImageOutput = useMemo( + () => + some(nodeTemplate?.outputs, (output) => + ['ImageField', 'ImageCollection'].includes(output.type) + ), + [nodeTemplate?.outputs] + ); + + const handleChangeIsIntermediate = useCallback( + (e: ChangeEvent) => { + dispatch( + fieldBooleanValueChanged({ + nodeId: nodeProps.data.id, + fieldName: 'is_intermediate', + value: !e.target.checked, + }) + ); + }, + [dispatch, nodeProps.data.id] + ); + + return ( + + + {hasImageOutput && ( + + Save Output + + + )} + + ); +}; + +export default memo(NodeFooter); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx new file mode 100644 index 0000000000..ab54ca2c44 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeNotesEdit.tsx @@ -0,0 +1,113 @@ +import { + Flex, + FormControl, + FormLabel, + Icon, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalFooter, + ModalHeader, + ModalOverlay, + Text, + Tooltip, + useDisclosure, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAITextarea from 'common/components/IAITextarea'; +import { nodeNotesChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { ChangeEvent, memo, useCallback } from 'react'; +import { FaInfoCircle } from 'react-icons/fa'; +import { NodeProps } from 'reactflow'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; +} + +const NodeNotesEdit = (props: Props) => { + const { nodeProps, nodeTemplate } = props; + const { data } = nodeProps; + const { isOpen, onOpen, onClose } = useDisclosure(); + const dispatch = useAppDispatch(); + const handleNotesChanged = useCallback( + (e: ChangeEvent) => { + dispatch(nodeNotesChanged({ nodeId: data.id, notes: e.target.value })); + }, + [data.id, dispatch] + ); + + return ( + <> + + ) : undefined + } + placement="top" + shouldWrapChildren + > + + + + + + + + + + {data.label || nodeTemplate?.title || 'Unknown Node'} + + + + + Notes + + + + + + + + ); +}; + +export default memo(NodeNotesEdit); + +type TooltipContentProps = Props; + +const TooltipContent = (props: TooltipContentProps) => { + return ( + + {props.nodeTemplate?.title} + + {props.nodeTemplate?.description} + + {props.nodeProps.data.notes && {props.nodeProps.data.notes}} + + ); +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx similarity index 73% rename from invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx rename to invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx index 1aca32ec70..6391e86471 100644 --- a/invokeai/frontend/web/src/features/nodes/components/IAINode/IAINodeResizer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeResizer.tsx @@ -2,7 +2,10 @@ import { NODE_MIN_WIDTH } from 'features/nodes/types/constants'; import { memo } from 'react'; import { NodeResizeControl, NodeResizerProps } from 'reactflow'; -const IAINodeResizer = (props: NodeResizerProps) => { +// this causes https://github.com/invoke-ai/InvokeAI/issues/4140 +// not using it for now + +const NodeResizer = (props: NodeResizerProps) => { const { ...rest } = props; return ( { ); }; -export default memo(IAINodeResizer); +export default memo(NodeResizer); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx new file mode 100644 index 0000000000..bf12358871 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeSettings.tsx @@ -0,0 +1,69 @@ +import { Flex } from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAIPopover from 'common/components/IAIPopover'; +import IAISwitch from 'common/components/IAISwitch'; +import { fieldBooleanValueChanged } from 'features/nodes/store/nodesSlice'; +import { InvocationNodeData } from 'features/nodes/types/types'; +import { ChangeEvent, memo, useCallback } from 'react'; +import { FaBars } from 'react-icons/fa'; + +interface Props { + data: InvocationNodeData; +} + +const NodeSettings = (props: Props) => { + const { data } = props; + const dispatch = useAppDispatch(); + + const handleChangeIsIntermediate = useCallback( + (e: ChangeEvent) => { + dispatch( + fieldBooleanValueChanged({ + nodeId: data.id, + fieldName: 'is_intermediate', + value: e.target.checked, + }) + ); + }, + [data.id, dispatch] + ); + + return ( + } + /> + } + > + + + + + ); +}; + +export default memo(NodeSettings); diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx new file mode 100644 index 0000000000..6695c4fd3b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeStatusIndicator.tsx @@ -0,0 +1,185 @@ +import { + Badge, + CircularProgress, + Flex, + Icon, + Image, + Text, + Tooltip, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppSelector } from 'app/store/storeHooks'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + NodeExecutionState, + NodeStatus, +} from 'features/nodes/types/types'; +import { memo, useMemo } from 'react'; +import { FaCheck, FaEllipsisH, FaExclamation } from 'react-icons/fa'; +import { NodeProps } from 'reactflow'; + +type Props = { + nodeProps: NodeProps; +}; + +const iconBoxSize = 3; +const circleStyles = { + circle: { + transitionProperty: 'none', + transitionDuration: '0s', + }, + '.chakra-progress__track': { stroke: 'transparent' }, +}; + +const NodeStatusIndicator = (props: Props) => { + const nodeId = props.nodeProps.data.id; + const selectNodeExecutionState = useMemo( + () => + createSelector( + stateSelector, + ({ nodes }) => nodes.nodeExecutionStates[nodeId] + ), + [nodeId] + ); + + const nodeExecutionState = useAppSelector(selectNodeExecutionState); + + if (!nodeExecutionState) { + return null; + } + + return ( + } + placement="top" + > + + + + + ); +}; + +export default memo(NodeStatusIndicator); + +type TooltipLabelProps = { + nodeExecutionState: NodeExecutionState; +}; + +const TooltipLabel = ({ nodeExecutionState }: TooltipLabelProps) => { + const { status, progress, progressImage } = nodeExecutionState; + if (status === NodeStatus.PENDING) { + return Pending; + } + + if (status === NodeStatus.IN_PROGRESS) { + if (progressImage) { + return ( + + + {progress !== null && ( + + {Math.round(progress * 100)}% + + )} + + ); + } + + if (progress !== null) { + return In Progress ({Math.round(progress * 100)}%); + } + + return In Progress; + } + + if (status === NodeStatus.COMPLETED) { + return Completed; + } + + if (status === NodeStatus.FAILED) { + return nodeExecutionState.error; + } + + return null; +}; + +type StatusIconProps = { + nodeExecutionState: NodeExecutionState; +}; + +const StatusIcon = (props: StatusIconProps) => { + const { progress, status } = props.nodeExecutionState; + if (status === NodeStatus.PENDING) { + return ( + + ); + } + if (status === NodeStatus.IN_PROGRESS) { + return progress === null ? ( + + ) : ( + + ); + } + if (status === NodeStatus.COMPLETED) { + return ( + + ); + } + if (status === NodeStatus.FAILED) { + return ( + + ); + } + return null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx new file mode 100644 index 0000000000..fa6a8ea224 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeTitle.tsx @@ -0,0 +1,123 @@ +import { + Box, + Editable, + EditableInput, + EditablePreview, + Flex, + useEditableControls, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { nodeLabelChanged } from 'features/nodes/store/nodesSlice'; +import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants'; +import { NodeData } from 'features/nodes/types/types'; +import { MouseEvent, memo, useCallback, useEffect, useState } from 'react'; + +type Props = { + nodeData: NodeData; + title: string; +}; + +const NodeTitle = (props: Props) => { + const { title } = props; + const { id: nodeId, label } = props.nodeData; + const dispatch = useAppDispatch(); + const [localTitle, setLocalTitle] = useState(label || title); + + const handleSubmit = useCallback( + async (newTitle: string) => { + dispatch(nodeLabelChanged({ nodeId, label: newTitle })); + setLocalTitle(newTitle || title); + }, + [nodeId, dispatch, title] + ); + + const handleChange = useCallback((newTitle: string) => { + setLocalTitle(newTitle); + }, []); + + useEffect(() => { + // Another component may change the title; sync local title with global state + setLocalTitle(label || title); + }, [label, title]); + + return ( + + + + + + + + ); +}; + +export default memo(NodeTitle); + +function EditableControls() { + const { isEditing, getEditButtonProps } = useEditableControls(); + const handleDoubleClick = useCallback( + (e: MouseEvent) => { + const { onClick } = getEditButtonProps(); + if (!onClick) { + return; + } + onClick(e); + }, + [getEditButtonProps] + ); + + if (isEditing) { + return null; + } + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx new file mode 100644 index 0000000000..2f555d700a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/Invocation/NodeWrapper.tsx @@ -0,0 +1,96 @@ +import { + Box, + ChakraProps, + useColorModeValue, + useToken, +} from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { nodeClicked } from 'features/nodes/store/nodesSlice'; +import { MouseEvent, PropsWithChildren, useCallback, useMemo } from 'react'; +import { DRAG_HANDLE_CLASSNAME, NODE_WIDTH } from '../../types/constants'; +import { NodeData } from 'features/nodes/types/types'; +import { NodeProps } from 'reactflow'; + +const useNodeSelect = (nodeId: string) => { + const dispatch = useAppDispatch(); + + const selectNode = useCallback( + (e: MouseEvent) => { + dispatch(nodeClicked({ nodeId, ctrlOrMeta: e.ctrlKey || e.metaKey })); + }, + [dispatch, nodeId] + ); + + return selectNode; +}; + +type NodeWrapperProps = PropsWithChildren & { + nodeProps: NodeProps; + width?: NonNullable['w']; +}; + +const NodeWrapper = (props: NodeWrapperProps) => { + const { width, children, nodeProps } = props; + const { data, selected } = nodeProps; + const nodeId = data.id; + + const [ + nodeSelectedOutlineLight, + nodeSelectedOutlineDark, + shadowsXl, + shadowsBase, + ] = useToken('shadows', [ + 'nodeSelectedOutline.light', + 'nodeSelectedOutline.dark', + 'shadows.xl', + 'shadows.base', + ]); + + const selectNode = useNodeSelect(nodeId); + + const shadow = useColorModeValue( + nodeSelectedOutlineLight, + nodeSelectedOutlineDark + ); + + const shift = useAppSelector((state) => state.hotkeys.shift); + const opacity = useAppSelector((state) => state.nodes.nodeOpacity); + const className = useMemo( + () => (shift ? DRAG_HANDLE_CLASSNAME : 'nopan'), + [shift] + ); + + return ( + + + {children} + + ); +}; + +export default NodeWrapper; diff --git a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx deleted file mode 100644 index 4c031afaff..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/InvocationComponent.tsx +++ /dev/null @@ -1,74 +0,0 @@ -import { Flex, Icon } from '@chakra-ui/react'; -import { FaExclamationCircle } from 'react-icons/fa'; -import { NodeProps } from 'reactflow'; -import { InvocationValue } from '../types/types'; - -import { useAppSelector } from 'app/store/storeHooks'; -import { memo, useMemo } from 'react'; -import { makeTemplateSelector } from '../store/util/makeTemplateSelector'; -import IAINodeHeader from './IAINode/IAINodeHeader'; -import IAINodeInputs from './IAINode/IAINodeInputs'; -import IAINodeOutputs from './IAINode/IAINodeOutputs'; -import IAINodeResizer from './IAINode/IAINodeResizer'; -import NodeWrapper from './NodeWrapper'; - -export const InvocationComponent = memo((props: NodeProps) => { - const { id: nodeId, data, selected } = props; - const { type, inputs, outputs } = data; - - const templateSelector = useMemo(() => makeTemplateSelector(type), [type]); - - const template = useAppSelector(templateSelector); - - if (!template) { - return ( - - - - - - - ); - } - - return ( - - - - - - - - - ); -}); - -InvocationComponent.displayName = 'InvocationComponent'; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx index 8c0480774c..8af9fefa90 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditor.tsx @@ -1,25 +1,45 @@ import { Box } from '@chakra-ui/react'; -import { ReactFlowProvider } from 'reactflow'; +import ResizeHandle from 'features/ui/components/tabs/ResizeHandle'; +import { memo, useState } from 'react'; +import { Panel, PanelGroup } from 'react-resizable-panels'; import 'reactflow/dist/style.css'; - -import { memo } from 'react'; import { Flow } from './Flow'; +import NodeEditorPanelGroup from './panel/NodeEditorPanelGroup'; const NodeEditor = () => { + const [isPanelCollapsed, setIsPanelCollapsed] = useState(false); return ( - - - - - + + + + + + + + + + ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx new file mode 100644 index 0000000000..58e2e3564e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/NodeEditorSettings.tsx @@ -0,0 +1,139 @@ +import { + Divider, + Flex, + Heading, + Modal, + ModalBody, + ModalCloseButton, + ModalContent, + ModalHeader, + ModalOverlay, + useDisclosure, +} from '@chakra-ui/react'; +import { createSelector } from '@reduxjs/toolkit'; +import { stateSelector } from 'app/store/store'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import IAIIconButton from 'common/components/IAIIconButton'; +import IAISwitch from 'common/components/IAISwitch'; +import { ChangeEvent, useCallback } from 'react'; +import { FaCog } from 'react-icons/fa'; +import { + shouldAnimateEdgesChanged, + shouldColorEdgesChanged, + shouldSnapToGridChanged, + shouldValidateGraphChanged, +} from '../store/nodesSlice'; + +const selector = createSelector(stateSelector, ({ nodes }) => { + const { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + } = nodes; + return { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + }; +}); + +const NodeEditorSettings = () => { + const { isOpen, onOpen, onClose } = useDisclosure(); + const dispatch = useAppDispatch(); + const { + shouldAnimateEdges, + shouldValidateGraph, + shouldSnapToGrid, + shouldColorEdges, + } = useAppSelector(selector); + + const handleChangeShouldValidate = useCallback( + (e: ChangeEvent) => { + dispatch(shouldValidateGraphChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldAnimate = useCallback( + (e: ChangeEvent) => { + dispatch(shouldAnimateEdgesChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldSnap = useCallback( + (e: ChangeEvent) => { + dispatch(shouldSnapToGridChanged(e.target.checked)); + }, + [dispatch] + ); + + const handleChangeShouldColor = useCallback( + (e: ChangeEvent) => { + dispatch(shouldColorEdgesChanged(e.target.checked)); + }, + [dispatch] + ); + + return ( + <> + } + onClick={onOpen} + /> + + + + + Node Editor Settings + + + + General + + + + + + + Advanced + + + + + + + + ); +}; + +export default NodeEditorSettings; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx index 1d498f19f5..4525dc5f6b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/NodeGraphOverlay.tsx @@ -1,34 +1,26 @@ -import { Box } from '@chakra-ui/react'; import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; -import { memo } from 'react'; +import ImageMetadataJSON from 'features/gallery/components/ImageMetadataViewer/ImageMetadataJSON'; +import { omit } from 'lodash-es'; +import { useMemo } from 'react'; +import { useDebounce } from 'use-debounce'; import { buildNodesGraph } from '../util/graphBuilders/buildNodesGraph'; -const NodeGraphOverlay = () => { - const state = useAppSelector((state: RootState) => state); - const graph = buildNodesGraph(state); - - return ( - - {JSON.stringify(graph, null, 2)} - +const useNodesGraph = () => { + const nodes = useAppSelector((state: RootState) => state.nodes); + const [debouncedNodes] = useDebounce(nodes, 300); + const graph = useMemo( + () => omit(buildNodesGraph(debouncedNodes), 'id'), + [debouncedNodes] ); + + return graph; }; -export default memo(NodeGraphOverlay); +const NodeGraph = () => { + const graph = useNodesGraph(); + + return ; +}; + +export default NodeGraph; diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx new file mode 100644 index 0000000000..693940859f --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/NodeOpacitySlider.tsx @@ -0,0 +1,42 @@ +import { + Box, + Slider, + SliderFilledTrack, + SliderThumb, + SliderTrack, +} from '@chakra-ui/react'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useCallback } from 'react'; +import { nodeOpacityChanged } from '../store/nodesSlice'; + +export default function NodeOpacitySlider() { + const dispatch = useAppDispatch(); + const nodeOpacity = useAppSelector((state) => state.nodes.nodeOpacity); + + const handleChange = useCallback( + (v: number) => { + dispatch(nodeOpacityChanged(v)); + }, + [dispatch] + ); + + return ( + + + + + + + + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx deleted file mode 100644 index bc7944a28b..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/NodeWrapper.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { Box, useToken } from '@chakra-ui/react'; -import { useAppSelector } from 'app/store/storeHooks'; -import { PropsWithChildren } from 'react'; -import { DRAG_HANDLE_CLASSNAME } from '../hooks/useBuildInvocation'; -import { NODE_MIN_WIDTH } from '../types/constants'; - -type NodeWrapperProps = PropsWithChildren & { - selected: boolean; -}; - -const NodeWrapper = (props: NodeWrapperProps) => { - const [nodeSelectedOutline, nodeShadow] = useToken('shadows', [ - 'nodeSelectedOutline', - 'dark-lg', - ]); - - const shift = useAppSelector((state) => state.hotkeys.shift); - - return ( - - {props.children} - - ); -}; - -export default NodeWrapper; diff --git a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx b/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx deleted file mode 100644 index 142e2a2990..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/ProgressImageNode.tsx +++ /dev/null @@ -1,73 +0,0 @@ -import { Flex, Image } from '@chakra-ui/react'; -import { RootState } from 'app/store/store'; -import { IAINoContentFallback } from 'common/components/IAIImageFallback'; -import { memo } from 'react'; -import { useDispatch, useSelector } from 'react-redux'; -import { NodeProps, OnResize } from 'reactflow'; -import { setProgressNodeSize } from '../store/nodesSlice'; -import IAINodeHeader from './IAINode/IAINodeHeader'; -import IAINodeResizer from './IAINode/IAINodeResizer'; -import NodeWrapper from './NodeWrapper'; - -const ProgressImageNode = (props: NodeProps) => { - const progressImage = useSelector( - (state: RootState) => state.system.progressImage - ); - const progressNodeSize = useSelector( - (state: RootState) => state.nodes.progressNodeSize - ); - const dispatch = useDispatch(); - const { selected } = props; - - const handleResize: OnResize = (_, newSize) => { - dispatch(setProgressNodeSize(newSize)); - }; - - return ( - - - - {progressImage ? ( - - ) : ( - - - - )} - - - - ); -}; - -export default memo(ProgressImageNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx index 796cdb010e..7416c6c555 100644 --- a/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/ViewportControls.tsx @@ -2,18 +2,16 @@ import { ButtonGroup, Tooltip } from '@chakra-ui/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import IAIIconButton from 'common/components/IAIIconButton'; import { memo, useCallback } from 'react'; -import { - FaCode, - FaExpand, - FaMinus, - FaPlus, - FaInfo, - FaMapMarkerAlt, -} from 'react-icons/fa'; -import { useReactFlow } from 'reactflow'; import { useTranslation } from 'react-i18next'; import { - shouldShowGraphOverlayChanged, + FaExpand, + FaInfo, + FaMapMarkerAlt, + FaMinus, + FaPlus, +} from 'react-icons/fa'; +import { useReactFlow } from 'reactflow'; +import { shouldShowFieldTypeLegendChanged, shouldShowMinimapPanelChanged, } from '../store/nodesSlice'; @@ -22,9 +20,6 @@ const ViewportControls = () => { const { t } = useTranslation(); const { zoomIn, zoomOut, fitView } = useReactFlow(); const dispatch = useAppDispatch(); - const shouldShowGraphOverlay = useAppSelector( - (state) => state.nodes.shouldShowGraphOverlay - ); const shouldShowFieldTypeLegend = useAppSelector( (state) => state.nodes.shouldShowFieldTypeLegend ); @@ -44,10 +39,6 @@ const ViewportControls = () => { fitView(); }, [fitView]); - const handleClickedToggleGraphOverlay = useCallback(() => { - dispatch(shouldShowGraphOverlayChanged(!shouldShowGraphOverlay)); - }, [shouldShowGraphOverlay, dispatch]); - const handleClickedToggleFieldTypeLegend = useCallback(() => { dispatch(shouldShowFieldTypeLegendChanged(!shouldShowFieldTypeLegend)); }, [shouldShowFieldTypeLegend, dispatch]); @@ -79,20 +70,6 @@ const ViewportControls = () => { icon={} /> - - } - /> - ( - + + + + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx similarity index 91% rename from invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx index 39142ed48e..8b7fb942a6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/panels/MinimapPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/MinimapPanel.tsx @@ -20,7 +20,7 @@ const MinimapPanel = () => { const nodeColor = useColorModeValue( 'var(--invokeai-colors-accent-300)', - 'var(--invokeai-colors-accent-700)' + 'var(--invokeai-colors-accent-600)' ); const maskColor = useColorModeValue( @@ -32,10 +32,9 @@ const MinimapPanel = () => { <> {shouldShowMinimapPanel && ( { return ( @@ -15,9 +14,8 @@ const TopCenterPanel = () => { - - + ); diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx similarity index 100% rename from invokeai/frontend/web/src/features/nodes/components/panels/TopLeftPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopLeftPanel.tsx diff --git a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx similarity index 55% rename from invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx rename to invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx index e3e3a871c8..7facf3973f 100644 --- a/invokeai/frontend/web/src/features/nodes/components/panels/TopRightPanel.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/editorPanels/TopRightPanel.tsx @@ -1,22 +1,16 @@ -import { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { memo } from 'react'; import { Panel } from 'reactflow'; import FieldTypeLegend from '../FieldTypeLegend'; -import NodeGraphOverlay from '../NodeGraphOverlay'; const TopRightPanel = () => { - const shouldShowGraphOverlay = useAppSelector( - (state: RootState) => state.nodes.shouldShowGraphOverlay - ); const shouldShowFieldTypeLegend = useAppSelector( - (state: RootState) => state.nodes.shouldShowFieldTypeLegend + (state) => state.nodes.shouldShowFieldTypeLegend ); return ( {shouldShowFieldTypeLegend && } - {shouldShowGraphOverlay && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx deleted file mode 100644 index 8e478c907c..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ArrayInputFieldComponent.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { - ArrayInputFieldTemplate, - ArrayInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FaList } from 'react-icons/fa'; -import { FieldComponentProps } from './types'; - -const ArrayInputFieldComponent = ( - _props: FieldComponentProps -) => { - return ; -}; - -export default memo(ArrayInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx deleted file mode 100644 index 5f26bc4f2a..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/EnumInputFieldComponent.tsx +++ /dev/null @@ -1,37 +0,0 @@ -import { Select } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; -import { - EnumInputFieldTemplate, - EnumInputFieldValue, -} from 'features/nodes/types/types'; -import { ChangeEvent, memo } from 'react'; -import { FieldComponentProps } from './types'; - -const EnumInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field, template } = props; - - const dispatch = useAppDispatch(); - - const handleValueChanged = (e: ChangeEvent) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; - - return ( - - ); -}; - -export default memo(EnumInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx new file mode 100644 index 0000000000..d9f8f951bc --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldContextMenu.tsx @@ -0,0 +1,47 @@ +import { MenuItem, MenuList } from '@chakra-ui/react'; +import { ContextMenu, ContextMenuProps } from 'chakra-ui-contextmenu'; +import { + InputFieldTemplate, + InputFieldValue, +} from 'features/nodes/types/types'; +import { MouseEvent, useCallback } from 'react'; +import { menuListMotionProps } from 'theme/components/menu'; + +type Props = { + nodeId: string; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + children: ContextMenuProps['children']; +}; + +const FieldContextMenu = (props: Props) => { + const skipEvent = useCallback((e: MouseEvent) => { + e.preventDefault(); + }, []); + + return ( + + menuProps={{ + size: 'sm', + isLazy: true, + }} + menuButtonProps={{ + bg: 'transparent', + _hover: { bg: 'transparent' }, + }} + renderMenu={() => ( + + Test + + )} + > + {props.children} + + ); +}; + +export default FieldContextMenu; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx new file mode 100644 index 0000000000..f47e68976d --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldHandle.tsx @@ -0,0 +1,122 @@ +import { Tooltip } from '@chakra-ui/react'; +import { CSSProperties, memo, useMemo } from 'react'; +import { Handle, HandleType, NodeProps, Position } from 'reactflow'; +import { + FIELDS, + HANDLE_TOOLTIP_OPEN_DELAY, + colorTokenToCssVar, +} from '../../types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, + OutputFieldTemplate, + OutputFieldValue, +} from '../../types/types'; + +export const handleBaseStyles: CSSProperties = { + position: 'absolute', + width: '1rem', + height: '1rem', + borderWidth: 0, + zIndex: 1, +}; + +export const inputHandleStyles: CSSProperties = { + left: '-1rem', +}; + +export const outputHandleStyles: CSSProperties = { + right: '-0.5rem', +}; + +type FieldHandleProps = { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: InputFieldValue | OutputFieldValue; + fieldTemplate: InputFieldTemplate | OutputFieldTemplate; + handleType: HandleType; + isConnectionInProgress: boolean; + isConnectionStartField: boolean; + connectionError: string | null; +}; + +const FieldHandle = (props: FieldHandleProps) => { + const { + fieldTemplate, + handleType, + isConnectionInProgress, + isConnectionStartField, + connectionError, + } = props; + const { name, type } = fieldTemplate; + const { color, title } = FIELDS[type]; + + const styles: CSSProperties = useMemo(() => { + const s: CSSProperties = { + backgroundColor: colorTokenToCssVar(color), + position: 'absolute', + width: '1rem', + height: '1rem', + borderWidth: 0, + zIndex: 1, + }; + + if (handleType === 'target') { + s.insetInlineStart = '-1rem'; + } else { + s.insetInlineEnd = '-1rem'; + } + + if (isConnectionInProgress && !isConnectionStartField && connectionError) { + s.filter = 'opacity(0.4) grayscale(0.7)'; + } + + if (isConnectionInProgress && connectionError) { + if (isConnectionStartField) { + s.cursor = 'grab'; + } else { + s.cursor = 'not-allowed'; + } + } else { + s.cursor = 'crosshair'; + } + + return s; + }, [ + color, + connectionError, + handleType, + isConnectionInProgress, + isConnectionStartField, + ]); + + const tooltip = useMemo(() => { + if (isConnectionInProgress && isConnectionStartField) { + return title; + } + if (isConnectionInProgress && connectionError) { + return connectionError ?? title; + } + return title; + }, [connectionError, isConnectionInProgress, isConnectionStartField, title]); + + return ( + + + + ); +}; + +export default memo(FieldHandle); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx new file mode 100644 index 0000000000..fc239addf3 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTitle.tsx @@ -0,0 +1,161 @@ +import { + Editable, + EditableInput, + EditablePreview, + Flex, + useEditableControls, +} from '@chakra-ui/react'; +import { useAppDispatch } from 'app/store/storeHooks'; +import IAIDraggable from 'common/components/IAIDraggable'; +import { NodeFieldDraggableData } from 'features/dnd/types'; +import { fieldLabelChanged } from 'features/nodes/store/nodesSlice'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { + MouseEvent, + memo, + useCallback, + useEffect, + useMemo, + useState, +} from 'react'; + +interface Props { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; + isDraggable?: boolean; +} + +const FieldTitle = (props: Props) => { + const { nodeData, field, fieldTemplate, isDraggable = false } = props; + const { label } = field; + const { title, input } = fieldTemplate; + const { id: nodeId } = nodeData; + const dispatch = useAppDispatch(); + const [localTitle, setLocalTitle] = useState(label || title); + + const draggableData: NodeFieldDraggableData | undefined = useMemo( + () => + input !== 'connection' && isDraggable + ? { + id: `${nodeId}-${field.name}`, + payloadType: 'NODE_FIELD', + payload: { nodeId, field, fieldTemplate }, + } + : undefined, + [field, fieldTemplate, input, isDraggable, nodeId] + ); + + const handleSubmit = useCallback( + async (newTitle: string) => { + dispatch( + fieldLabelChanged({ nodeId, fieldName: field.name, label: newTitle }) + ); + setLocalTitle(newTitle || title); + }, + [dispatch, nodeId, field.name, title] + ); + + const handleChange = useCallback((newTitle: string) => { + setLocalTitle(newTitle); + }, []); + + useEffect(() => { + // Another component may change the title; sync local title with global state + setLocalTitle(label || title); + }, [label, title]); + + return ( + + + + + + + + ); +}; + +export default memo(FieldTitle); + +type EditableControlsProps = { + draggableData?: NodeFieldDraggableData; +}; + +function EditableControls(props: EditableControlsProps) { + const { isEditing, getEditButtonProps } = useEditableControls(); + const handleDoubleClick = useCallback( + (e: MouseEvent) => { + const { onClick } = getEditButtonProps(); + if (!onClick) { + return; + } + onClick(e); + }, + [getEditButtonProps] + ); + + if (isEditing) { + return null; + } + + if (props.draggableData) { + return ( + + ); + } + + return ( + + ); +} diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx new file mode 100644 index 0000000000..bf5cd3cd9b --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/FieldTooltipContent.tsx @@ -0,0 +1,41 @@ +import { Flex, Text } from '@chakra-ui/react'; +import { FIELDS } from 'features/nodes/types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, + OutputFieldTemplate, + OutputFieldValue, + isInputFieldTemplate, + isInputFieldValue, +} from 'features/nodes/types/types'; +import { startCase } from 'lodash-es'; + +interface Props { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue | OutputFieldValue; + fieldTemplate: InputFieldTemplate | OutputFieldTemplate; +} + +const FieldTooltipContent = ({ field, fieldTemplate }: Props) => { + const isInputTemplate = isInputFieldTemplate(fieldTemplate); + + return ( + + + {isInputFieldValue(field) && field.label + ? `${field.label} (${fieldTemplate.title})` + : fieldTemplate.title} + + + {fieldTemplate.description} + + Type: {FIELDS[fieldTemplate.type].title} + {isInputTemplate && Input: {startCase(fieldTemplate.input)}} + + ); +}; + +export default FieldTooltipContent; diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx new file mode 100644 index 0000000000..67f4369384 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputField.tsx @@ -0,0 +1,153 @@ +import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; +import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { PropsWithChildren, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; +import FieldHandle from './FieldHandle'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; +} + +const InputField = (props: Props) => { + const { nodeProps, nodeTemplate, field } = props; + const { id: nodeId } = nodeProps.data; + + const { + isConnected, + isConnectionInProgress, + isConnectionStartField, + connectionError, + shouldDim, + } = useConnectionState({ nodeId, field, kind: 'input' }); + + const fieldTemplate = useMemo( + () => nodeTemplate.inputs[field.name], + [field.name, nodeTemplate.inputs] + ); + + const isMissingInput = useMemo(() => { + if (!fieldTemplate) { + return false; + } + + if (!fieldTemplate.required) { + return false; + } + + if (!isConnected && fieldTemplate.input === 'connection') { + return true; + } + + if (!field.value && !isConnected && fieldTemplate.input === 'any') { + return true; + } + }, [fieldTemplate, isConnected, field.value]); + + if (!fieldTemplate) { + return ( + + + Unknown input: {field.name} + + + ); + } + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + + + + + + {fieldTemplate.input !== 'direct' && ( + + )} + + ); +}; + +export default InputField; + +type InputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +const InputFieldWrapper = ({ shouldDim, children }: InputFieldWrapperProps) => ( + + {children} + +); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx new file mode 100644 index 0000000000..ce9d88af0a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/InputFieldRenderer.tsx @@ -0,0 +1,293 @@ +import { Box } from '@chakra-ui/react'; +import { memo } from 'react'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from '../../types/types'; +import BooleanInputField from './fieldTypes/BooleanInputField'; +import ClipInputField from './fieldTypes/ClipInputField'; +import CollectionInputField from './fieldTypes/CollectionInputField'; +import CollectionItemInputField from './fieldTypes/CollectionItemInputField'; +import ColorInputField from './fieldTypes/ColorInputField'; +import ConditioningInputField from './fieldTypes/ConditioningInputField'; +import ControlInputField from './fieldTypes/ControlInputField'; +import ControlNetModelInputField from './fieldTypes/ControlNetModelInputField'; +import EnumInputField from './fieldTypes/EnumInputField'; +import ImageCollectionInputField from './fieldTypes/ImageCollectionInputField'; +import ImageInputField from './fieldTypes/ImageInputField'; +import LatentsInputField from './fieldTypes/LatentsInputField'; +import LoRAModelInputField from './fieldTypes/LoRAModelInputField'; +import MainModelInputField from './fieldTypes/MainModelInputField'; +import NumberInputField from './fieldTypes/NumberInputField'; +import RefinerModelInputField from './fieldTypes/RefinerModelInputField'; +import SDXLMainModelInputField from './fieldTypes/SDXLMainModelInputField'; +import StringInputField from './fieldTypes/StringInputField'; +import UnetInputField from './fieldTypes/UnetInputField'; +import VaeInputField from './fieldTypes/VaeInputField'; +import VaeModelInputField from './fieldTypes/VaeModelInputField'; + +type InputFieldProps = { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; +}; + +// build an individual input element based on the schema +const InputFieldRenderer = (props: InputFieldProps) => { + const { nodeData, nodeTemplate, field, fieldTemplate } = props; + const { type } = field; + + if (type === 'string' && fieldTemplate.type === 'string') { + return ( + + ); + } + + if (type === 'boolean' && fieldTemplate.type === 'boolean') { + return ( + + ); + } + + if ( + (type === 'integer' && fieldTemplate.type === 'integer') || + (type === 'float' && fieldTemplate.type === 'float') || + (type === 'Seed' && fieldTemplate.type === 'Seed') + ) { + return ( + + ); + } + + if (type === 'enum' && fieldTemplate.type === 'enum') { + return ( + + ); + } + + if (type === 'ImageField' && fieldTemplate.type === 'ImageField') { + return ( + + ); + } + + if (type === 'LatentsField' && fieldTemplate.type === 'LatentsField') { + return ( + + ); + } + + if ( + type === 'ConditioningField' && + fieldTemplate.type === 'ConditioningField' + ) { + return ( + + ); + } + + if (type === 'UNetField' && fieldTemplate.type === 'UNetField') { + return ( + + ); + } + + if (type === 'ClipField' && fieldTemplate.type === 'ClipField') { + return ( + + ); + } + + if (type === 'VaeField' && fieldTemplate.type === 'VaeField') { + return ( + + ); + } + + if (type === 'ControlField' && fieldTemplate.type === 'ControlField') { + return ( + + ); + } + + if (type === 'MainModelField' && fieldTemplate.type === 'MainModelField') { + return ( + + ); + } + + if ( + type === 'SDXLRefinerModelField' && + fieldTemplate.type === 'SDXLRefinerModelField' + ) { + return ( + + ); + } + + if (type === 'VaeModelField' && fieldTemplate.type === 'VaeModelField') { + return ( + + ); + } + + if (type === 'LoRAModelField' && fieldTemplate.type === 'LoRAModelField') { + return ( + + ); + } + + if ( + type === 'ControlNetModelField' && + fieldTemplate.type === 'ControlNetModelField' + ) { + return ( + + ); + } + + if (type === 'Collection' && fieldTemplate.type === 'Collection') { + return ( + + ); + } + + if (type === 'CollectionItem' && fieldTemplate.type === 'CollectionItem') { + return ( + + ); + } + + if (type === 'ColorField' && fieldTemplate.type === 'ColorField') { + return ( + + ); + } + + if (type === 'ImageCollection' && fieldTemplate.type === 'ImageCollection') { + return ( + + ); + } + + if ( + type === 'SDXLMainModelField' && + fieldTemplate.type === 'SDXLMainModelField' + ) { + return ( + + ); + } + + return Unknown field type: {type}; +}; + +export default memo(InputFieldRenderer); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx deleted file mode 100644 index 6fa89345bf..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/ItemInputFieldComponent.tsx +++ /dev/null @@ -1,15 +0,0 @@ -import { - ItemInputFieldTemplate, - ItemInputFieldValue, -} from 'features/nodes/types/types'; -import { memo } from 'react'; -import { FaAddressCard } from 'react-icons/fa'; -import { FieldComponentProps } from './types'; - -const ItemInputFieldComponent = ( - _props: FieldComponentProps -) => { - return ; -}; - -export default memo(ItemInputFieldComponent); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx new file mode 100644 index 0000000000..98a8000b1a --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/LinearViewField.tsx @@ -0,0 +1,88 @@ +import { Flex, FormControl, FormLabel, Tooltip } from '@chakra-ui/react'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InputFieldTemplate, + InputFieldValue, + InvocationNodeData, + InvocationTemplate, +} from 'features/nodes/types/types'; +import { memo } from 'react'; +import FieldTitle from './FieldTitle'; +import FieldTooltipContent from './FieldTooltipContent'; +import InputFieldRenderer from './InputFieldRenderer'; + +type Props = { + nodeData: InvocationNodeData; + nodeTemplate: InvocationTemplate; + field: InputFieldValue; + fieldTemplate: InputFieldTemplate; +}; + +const LinearViewField = ({ + nodeData, + nodeTemplate, + field, + fieldTemplate, +}: Props) => { + // const dispatch = useAppDispatch(); + // const handleRemoveField = useCallback(() => { + // dispatch( + // workflowExposedFieldRemoved({ + // nodeId: nodeData.id, + // fieldName: field.name, + // }) + // ); + // }, [dispatch, field.name, nodeData.id]); + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + + + + + + ); +}; + +export default memo(LinearViewField); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx new file mode 100644 index 0000000000..5a29d1ab7e --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/fields/OutputField.tsx @@ -0,0 +1,114 @@ +import { + Flex, + FormControl, + FormLabel, + Spacer, + Tooltip, +} from '@chakra-ui/react'; +import { useConnectionState } from 'features/nodes/hooks/useConnectionState'; +import { HANDLE_TOOLTIP_OPEN_DELAY } from 'features/nodes/types/constants'; +import { + InvocationNodeData, + InvocationTemplate, + OutputFieldValue, +} from 'features/nodes/types/types'; +import { PropsWithChildren, useMemo } from 'react'; +import { NodeProps } from 'reactflow'; +import FieldHandle from './FieldHandle'; +import FieldTooltipContent from './FieldTooltipContent'; + +interface Props { + nodeProps: NodeProps; + nodeTemplate: InvocationTemplate; + field: OutputFieldValue; +} + +const OutputField = (props: Props) => { + const { nodeTemplate, nodeProps, field } = props; + + const { + isConnected, + isConnectionInProgress, + isConnectionStartField, + connectionError, + shouldDim, + } = useConnectionState({ nodeId: nodeProps.data.id, field, kind: 'output' }); + + const fieldTemplate = useMemo( + () => nodeTemplate.outputs[field.name], + [field.name, nodeTemplate] + ); + + if (!fieldTemplate) { + return ( + + + Unknown output: {field.name} + + + ); + } + + return ( + + + + } + openDelay={HANDLE_TOOLTIP_OPEN_DELAY} + placement="top" + shouldWrapChildren + hasArrow + > + + + {fieldTemplate?.title} + + + + + + ); +}; + +export default OutputField; + +type OutputFieldWrapperProps = PropsWithChildren<{ + shouldDim: boolean; +}>; + +const OutputFieldWrapper = ({ + shouldDim, + children, +}: OutputFieldWrapperProps) => ( + + {children} + +); diff --git a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx deleted file mode 100644 index 18cf7e997f..0000000000 --- a/invokeai/frontend/web/src/features/nodes/components/fields/StringInputFieldComponent.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { Input, Textarea } from '@chakra-ui/react'; -import { useAppDispatch } from 'app/store/storeHooks'; -import { fieldValueChanged } from 'features/nodes/store/nodesSlice'; -import { - StringInputFieldTemplate, - StringInputFieldValue, -} from 'features/nodes/types/types'; -import { ChangeEvent, memo } from 'react'; -import { FieldComponentProps } from './types'; - -const StringInputFieldComponent = ( - props: FieldComponentProps -) => { - const { nodeId, field } = props; - const dispatch = useAppDispatch(); - - const handleValueChanged = ( - e: ChangeEvent - ) => { - dispatch( - fieldValueChanged({ - nodeId, - fieldName: field.name, - value: e.target.value, - }) - ); - }; - - return ['prompt', 'style'].includes(field.name.toLowerCase()) ? ( -