feat(ui): recreate edge auto-add-node logic

This commit is contained in:
psychedelicious 2024-05-16 19:17:56 +10:00
parent 2c1fa30639
commit 4d68cd8dbb
4 changed files with 87 additions and 96 deletions

View File

@ -4,11 +4,22 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster'; import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select'; import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; import {
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
nodeAdded,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es'; import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react'; import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react'; import { memo, useCallback, useMemo, useRef } from 'react';
@ -17,6 +28,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters'; import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize( const createRegex = memoize(
(inputValue: string) => (inputValue: string) =>
@ -50,26 +62,29 @@ const AddNodePopover = () => {
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null); const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const templates = useStore($templates); const templates = useStore($templates);
const pendingConnection = useStore($pendingConnection);
const isOpen = useStore($isAddNodePopoverOpen);
const store = useAppStore();
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType); const filteredTemplates = useMemo(() => {
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType); // If we have a connection in progress, we need to filter the node choices
if (!pendingConnection) {
return map(templates);
}
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
const options = useMemo(() => { const options = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
const filteredNodeTemplates = fieldFilter
? filter(templates, (template) => {
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => {
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(templates);
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
return { return {
label: template.title, label: template.title,
value: template.type, value: template.type,
@ -79,15 +94,15 @@ const AddNodePopover = () => {
}); });
//We only want these nodes if we're not filtered //We only want these nodes if we're not filtered
if (fieldFilter === null) { if (!pendingConnection) {
options.push({ _options.push({
label: t('nodes.currentImage'), label: t('nodes.currentImage'),
value: 'current_image', value: 'current_image',
description: t('nodes.currentImageDescription'), description: t('nodes.currentImageDescription'),
tags: ['progress'], tags: ['progress'],
}); });
options.push({ _options.push({
label: t('nodes.notes'), label: t('nodes.notes'),
value: 'notes', value: 'notes',
description: t('nodes.notesDescription'), description: t('nodes.notesDescription'),
@ -95,15 +110,13 @@ const AddNodePopover = () => {
}); });
} }
options.sort((a, b) => a.label.localeCompare(b.label)); _options.sort((a, b) => a.label.localeCompare(b.label));
return options; return _options;
}, [fieldFilter, handleFilter, t, templates]); }, [filteredTemplates, pendingConnection, t]);
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
const addNode = useCallback( const addNode = useCallback(
(nodeType: string) => { (nodeType: string): AnyNode | null => {
const invocation = buildInvocation(nodeType); const invocation = buildInvocation(nodeType);
if (!invocation) { if (!invocation) {
const errorMessage = t('nodes.unknownNode', { const errorMessage = t('nodes.unknownNode', {
@ -113,10 +126,11 @@ const AddNodePopover = () => {
status: 'error', status: 'error',
title: errorMessage, title: errorMessage,
}); });
return; return null;
} }
dispatch(nodeAdded(invocation)); dispatch(nodeAdded(invocation));
return invocation;
}, },
[dispatch, buildInvocation, toaster, t] [dispatch, buildInvocation, toaster, t]
); );
@ -126,52 +140,50 @@ const AddNodePopover = () => {
if (!v) { if (!v) {
return; return;
} }
addNode(v.value); const node = addNode(v.value);
dispatch(addNodePopoverClosed());
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
if (connection) {
dispatch(connectionMade(connection));
}
}
closeAddNodePopover();
}, },
[addNode, dispatch] [addNode, dispatch, pendingConnection, store, templates]
); );
const onClose = useCallback(() => { const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
dispatch(addNodePopoverClosed()); e.preventDefault();
}, [dispatch]); openAddNodePopover();
flushSync(() => {
const onOpen = useCallback(() => { selectRef.current?.inputRef?.focus();
dispatch(addNodePopoverOpened()); });
}, [dispatch]); }, []);
const handleHotkeyOpen: HotkeyCallback = useCallback(
(e) => {
e.preventDefault();
onOpen();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
},
[onOpen]
);
const handleHotkeyClose: HotkeyCallback = useCallback(() => { const handleHotkeyClose: HotkeyCallback = useCallback(() => {
onClose(); closeAddNodePopover();
}, [onClose]); }, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen); useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose); useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback( const onKeyDown: KeyboardEventHandler = useCallback((e) => {
(e) => { if (e.key === 'Escape') {
if (e.key === 'Escape') { closeAddNodePopover();
onClose(); }
} }, []);
},
[onClose]
);
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]); const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
return ( return (
<Popover <Popover
isOpen={isOpen} isOpen={isOpen}
onClose={onClose} onClose={closeAddNodePopover}
placement="bottom" placement="bottom"
openDelay={0} openDelay={0}
closeDelay={0} closeDelay={0}
@ -201,7 +213,7 @@ const AddNodePopover = () => {
noOptionsMessage={noOptionsMessage} noOptionsMessage={noOptionsMessage}
filterOption={filterOption} filterOption={filterOption}
onChange={onChange} onChange={onChange}
onMenuClose={onClose} onMenuClose={closeAddNodePopover}
onKeyDown={onKeyDown} onKeyDown={onKeyDown}
inputRef={inputRef} inputRef={inputRef}
closeMenuOnSelect={false} closeMenuOnSelect={false}

View File

@ -1,5 +1,4 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection'; import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste'; import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
@ -7,7 +6,6 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher'; import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import { import {
$cursorPos, $cursorPos,
$pendingConnection,
connectionMade, connectionMade,
edgeAdded, edgeAdded,
edgeChangeStarted, edgeChangeStarted,
@ -100,36 +98,6 @@ export const Flow = memo(() => {
[dispatch] [dispatch]
); );
// const onConnectStart: OnConnectStart = useCallback(
// (event, params) => {
// dispatch(connectionStarted(params));
// },
// [dispatch]
// );
// const onConnect: OnConnect = useCallback(
// (connection) => {
// dispatch(connectionMade(connection));
// },
// [dispatch]
// );
// const onConnectEnd: OnConnectEnd = useCallback(() => {
// const cursorPosition = $cursorPos.get();
// if (!cursorPosition) {
// return;
// }
// dispatch(
// connectionEnded({
// cursorPosition,
// mouseOverNodeId: $mouseOverNode.get(),
// })
// );
// }, [dispatch]);
const pendingConnection = useStore($pendingConnection);
console.log(pendingConnection)
const onEdgesDelete: OnEdgesDelete = useCallback( const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => { (edges) => {
dispatch(edgesDeleted(edges)); dispatch(edgesDeleted(edges));

View File

@ -1,7 +1,7 @@
import { useStore } from '@nanostores/react'; import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks'; import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode'; import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice'; import { $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle'; import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react'; import { useCallback, useMemo } from 'react';
@ -59,8 +59,11 @@ export const useConnection = () => {
if (connection) { if (connection) {
dispatch(connectionMade(connection)); dispatch(connectionMade(connection));
} }
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
} }
$pendingConnection.set(null);
}, [store, templates]); }, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]); const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);

View File

@ -760,6 +760,14 @@ export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]); export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null); export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isModifyingEdge = atom(false); export const $isModifyingEdge = atom(false);
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
$isAddNodePopoverOpen.set(false);
$pendingConnection.set(null);
};
export const openAddNodePopover = () => {
$isAddNodePopoverOpen.set(true);
};
export const selectNodesSlice = (state: RootState) => state.nodes.present; export const selectNodesSlice = (state: RootState) => state.nodes.present;