From 4d68cd8dbb1518b6f932910e89076a0f9b651779 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 16 May 2024 19:17:56 +1000 Subject: [PATCH] feat(ui): recreate edge auto-add-node logic --- .../flow/AddNodePopover/AddNodePopover.tsx | 136 ++++++++++-------- .../features/nodes/components/flow/Flow.tsx | 32 ----- .../src/features/nodes/hooks/useConnection.ts | 7 +- .../src/features/nodes/store/nodesSlice.ts | 8 ++ 4 files changed, 87 insertions(+), 96 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 6d33905f4c..d9602a9679 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -4,11 +4,22 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import { useAppToaster } from 'app/components/Toaster'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppStore } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; +import { + $isAddNodePopoverOpen, + $pendingConnection, + $templates, + closeAddNodePopover, + connectionMade, + nodeAdded, + openAddNodePopover, +} from 'features/nodes/store/nodesSlice'; +import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import type { AnyNode } from 'features/nodes/types/invocation'; +import { isInvocationNode } from 'features/nodes/types/invocation'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react'; @@ -17,6 +28,7 @@ import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import { useTranslation } from 'react-i18next'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; +import { assert } from 'tsafe'; const createRegex = memoize( (inputValue: string) => @@ -50,26 +62,29 @@ const AddNodePopover = () => { const selectRef = useRef | null>(null); const inputRef = useRef(null); const templates = useStore($templates); + const pendingConnection = useStore($pendingConnection); + const isOpen = useStore($isAddNodePopoverOpen); + const store = useAppStore(); - const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType); - const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType); + const filteredTemplates = useMemo(() => { + // If we have a connection in progress, we need to filter the node choices + if (!pendingConnection) { + return map(templates); + } + + return filter(templates, (template) => { + const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind; + const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs; + return some(fields, (field) => { + const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type; + const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type; + return validateSourceAndTargetTypes(sourceType, targetType); + }); + }); + }, [templates, pendingConnection]); const options = useMemo(() => { - // If we have a connection in progress, we need to filter the node choices - const filteredNodeTemplates = fieldFilter - ? filter(templates, (template) => { - const handles = handleFilter === 'source' ? template.inputs : template.outputs; - - return some(handles, (handle) => { - const sourceType = handleFilter === 'source' ? fieldFilter : handle.type; - const targetType = handleFilter === 'target' ? fieldFilter : handle.type; - - return validateSourceAndTargetTypes(sourceType, targetType); - }); - }) - : map(templates); - - const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { + const _options: ComboboxOption[] = map(filteredTemplates, (template) => { return { label: template.title, value: template.type, @@ -79,15 +94,15 @@ const AddNodePopover = () => { }); //We only want these nodes if we're not filtered - if (fieldFilter === null) { - options.push({ + if (!pendingConnection) { + _options.push({ label: t('nodes.currentImage'), value: 'current_image', description: t('nodes.currentImageDescription'), tags: ['progress'], }); - options.push({ + _options.push({ label: t('nodes.notes'), value: 'notes', description: t('nodes.notesDescription'), @@ -95,15 +110,13 @@ const AddNodePopover = () => { }); } - options.sort((a, b) => a.label.localeCompare(b.label)); + _options.sort((a, b) => a.label.localeCompare(b.label)); - return options; - }, [fieldFilter, handleFilter, t, templates]); - - const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen); + return _options; + }, [filteredTemplates, pendingConnection, t]); const addNode = useCallback( - (nodeType: string) => { + (nodeType: string): AnyNode | null => { const invocation = buildInvocation(nodeType); if (!invocation) { const errorMessage = t('nodes.unknownNode', { @@ -113,10 +126,11 @@ const AddNodePopover = () => { status: 'error', title: errorMessage, }); - return; + return null; } dispatch(nodeAdded(invocation)); + return invocation; }, [dispatch, buildInvocation, toaster, t] ); @@ -126,52 +140,50 @@ const AddNodePopover = () => { if (!v) { return; } - addNode(v.value); - dispatch(addNodePopoverClosed()); + const node = addNode(v.value); + + // Auto-connect an edge if we just added a node and have a pending connection + if (pendingConnection && isInvocationNode(node)) { + const template = templates[node.data.type]; + assert(template, 'Template not found'); + const { nodes, edges } = store.getState().nodes.present; + const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template); + if (connection) { + dispatch(connectionMade(connection)); + } + } + + closeAddNodePopover(); }, - [addNode, dispatch] + [addNode, dispatch, pendingConnection, store, templates] ); - const onClose = useCallback(() => { - dispatch(addNodePopoverClosed()); - }, [dispatch]); - - const onOpen = useCallback(() => { - dispatch(addNodePopoverOpened()); - }, [dispatch]); - - const handleHotkeyOpen: HotkeyCallback = useCallback( - (e) => { - e.preventDefault(); - onOpen(); - flushSync(() => { - selectRef.current?.inputRef?.focus(); - }); - }, - [onOpen] - ); + const handleHotkeyOpen: HotkeyCallback = useCallback((e) => { + e.preventDefault(); + openAddNodePopover(); + flushSync(() => { + selectRef.current?.inputRef?.focus(); + }); + }, []); const handleHotkeyClose: HotkeyCallback = useCallback(() => { - onClose(); - }, [onClose]); + closeAddNodePopover(); + }, []); useHotkeys(['shift+a', 'space'], handleHotkeyOpen); useHotkeys(['escape'], handleHotkeyClose); - const onKeyDown: KeyboardEventHandler = useCallback( - (e) => { - if (e.key === 'Escape') { - onClose(); - } - }, - [onClose] - ); + const onKeyDown: KeyboardEventHandler = useCallback((e) => { + if (e.key === 'Escape') { + closeAddNodePopover(); + } + }, []); const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]); return ( { noOptionsMessage={noOptionsMessage} filterOption={filterOption} onChange={onChange} - onMenuClose={onClose} + onMenuClose={closeAddNodePopover} onKeyDown={onKeyDown} inputRef={inputRef} closeMenuOnSelect={false} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 16634d2ada..fa04d60bca 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,5 +1,4 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; -import { useStore } from '@nanostores/react'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useConnection } from 'features/nodes/hooks/useConnection'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; @@ -7,7 +6,6 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection' import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { $cursorPos, - $pendingConnection, connectionMade, edgeAdded, edgeChangeStarted, @@ -100,36 +98,6 @@ export const Flow = memo(() => { [dispatch] ); - // const onConnectStart: OnConnectStart = useCallback( - // (event, params) => { - // dispatch(connectionStarted(params)); - // }, - // [dispatch] - // ); - - // const onConnect: OnConnect = useCallback( - // (connection) => { - // dispatch(connectionMade(connection)); - // }, - // [dispatch] - // ); - - // const onConnectEnd: OnConnectEnd = useCallback(() => { - // const cursorPosition = $cursorPos.get(); - // if (!cursorPosition) { - // return; - // } - // dispatch( - // connectionEnded({ - // cursorPosition, - // mouseOverNodeId: $mouseOverNode.get(), - // }) - // ); - // }, [dispatch]); - - const pendingConnection = useStore($pendingConnection); - console.log(pendingConnection) - const onEdgesDelete: OnEdgesDelete = useCallback( (edges) => { dispatch(edgesDeleted(edges)); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index fb6212cd26..468a0bd645 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -1,7 +1,7 @@ import { useStore } from '@nanostores/react'; import { useAppStore } from 'app/store/storeHooks'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; -import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice'; +import { $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice'; import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { useCallback, useMemo } from 'react'; @@ -59,8 +59,11 @@ export const useConnection = () => { if (connection) { dispatch(connectionMade(connection)); } + $pendingConnection.set(null); + } else { + // The mouse is not over a node - we should open the add node popover + $isAddNodePopoverOpen.set(true); } - $pendingConnection.set(null); }, [store, templates]); const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 0418515f5d..4ba4e2c0fe 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -760,6 +760,14 @@ export const $copiedNodes = atom([]); export const $copiedEdges = atom([]); export const $pendingConnection = atom(null); export const $isModifyingEdge = atom(false); +export const $isAddNodePopoverOpen = atom(false); +export const closeAddNodePopover = () => { + $isAddNodePopoverOpen.set(false); + $pendingConnection.set(null); +}; +export const openAddNodePopover = () => { + $isAddNodePopoverOpen.set(true); +}; export const selectNodesSlice = (state: RootState) => state.nodes.present;