feat(ui): recreate edge auto-add-node logic

This commit is contained in:
psychedelicious 2024-05-16 19:17:56 +10:00
parent 2c1fa30639
commit 4d68cd8dbb
4 changed files with 87 additions and 96 deletions

View File

@ -4,11 +4,22 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppToaster } from 'app/components/Toaster';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
import type { SelectInstance } from 'chakra-react-select';
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
import {
$isAddNodePopoverOpen,
$pendingConnection,
$templates,
closeAddNodePopover,
connectionMade,
nodeAdded,
openAddNodePopover,
} from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
import type { AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { filter, map, memoize, some } from 'lodash-es';
import type { KeyboardEventHandler } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
@ -17,6 +28,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
import { useTranslation } from 'react-i18next';
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
import { assert } from 'tsafe';
const createRegex = memoize(
(inputValue: string) =>
@ -50,26 +62,29 @@ const AddNodePopover = () => {
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
const inputRef = useRef<HTMLInputElement>(null);
const templates = useStore($templates);
const pendingConnection = useStore($pendingConnection);
const isOpen = useStore($isAddNodePopoverOpen);
const store = useAppStore();
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
const filteredTemplates = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
if (!pendingConnection) {
return map(templates);
}
return filter(templates, (template) => {
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
return some(fields, (field) => {
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
});
}, [templates, pendingConnection]);
const options = useMemo(() => {
// If we have a connection in progress, we need to filter the node choices
const filteredNodeTemplates = fieldFilter
? filter(templates, (template) => {
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
return some(handles, (handle) => {
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
return validateSourceAndTargetTypes(sourceType, targetType);
});
})
: map(templates);
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
return {
label: template.title,
value: template.type,
@ -79,15 +94,15 @@ const AddNodePopover = () => {
});
//We only want these nodes if we're not filtered
if (fieldFilter === null) {
options.push({
if (!pendingConnection) {
_options.push({
label: t('nodes.currentImage'),
value: 'current_image',
description: t('nodes.currentImageDescription'),
tags: ['progress'],
});
options.push({
_options.push({
label: t('nodes.notes'),
value: 'notes',
description: t('nodes.notesDescription'),
@ -95,15 +110,13 @@ const AddNodePopover = () => {
});
}
options.sort((a, b) => a.label.localeCompare(b.label));
_options.sort((a, b) => a.label.localeCompare(b.label));
return options;
}, [fieldFilter, handleFilter, t, templates]);
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
return _options;
}, [filteredTemplates, pendingConnection, t]);
const addNode = useCallback(
(nodeType: string) => {
(nodeType: string): AnyNode | null => {
const invocation = buildInvocation(nodeType);
if (!invocation) {
const errorMessage = t('nodes.unknownNode', {
@ -113,10 +126,11 @@ const AddNodePopover = () => {
status: 'error',
title: errorMessage,
});
return;
return null;
}
dispatch(nodeAdded(invocation));
return invocation;
},
[dispatch, buildInvocation, toaster, t]
);
@ -126,52 +140,50 @@ const AddNodePopover = () => {
if (!v) {
return;
}
addNode(v.value);
dispatch(addNodePopoverClosed());
const node = addNode(v.value);
// Auto-connect an edge if we just added a node and have a pending connection
if (pendingConnection && isInvocationNode(node)) {
const template = templates[node.data.type];
assert(template, 'Template not found');
const { nodes, edges } = store.getState().nodes.present;
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
if (connection) {
dispatch(connectionMade(connection));
}
}
closeAddNodePopover();
},
[addNode, dispatch]
[addNode, dispatch, pendingConnection, store, templates]
);
const onClose = useCallback(() => {
dispatch(addNodePopoverClosed());
}, [dispatch]);
const onOpen = useCallback(() => {
dispatch(addNodePopoverOpened());
}, [dispatch]);
const handleHotkeyOpen: HotkeyCallback = useCallback(
(e) => {
e.preventDefault();
onOpen();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
},
[onOpen]
);
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
e.preventDefault();
openAddNodePopover();
flushSync(() => {
selectRef.current?.inputRef?.focus();
});
}, []);
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
onClose();
}, [onClose]);
closeAddNodePopover();
}, []);
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
useHotkeys(['escape'], handleHotkeyClose);
const onKeyDown: KeyboardEventHandler = useCallback(
(e) => {
if (e.key === 'Escape') {
onClose();
}
},
[onClose]
);
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
if (e.key === 'Escape') {
closeAddNodePopover();
}
}, []);
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
return (
<Popover
isOpen={isOpen}
onClose={onClose}
onClose={closeAddNodePopover}
placement="bottom"
openDelay={0}
closeDelay={0}
@ -201,7 +213,7 @@ const AddNodePopover = () => {
noOptionsMessage={noOptionsMessage}
filterOption={filterOption}
onChange={onChange}
onMenuClose={onClose}
onMenuClose={closeAddNodePopover}
onKeyDown={onKeyDown}
inputRef={inputRef}
closeMenuOnSelect={false}

View File

@ -1,5 +1,4 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
@ -7,7 +6,6 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
import {
$cursorPos,
$pendingConnection,
connectionMade,
edgeAdded,
edgeChangeStarted,
@ -100,36 +98,6 @@ export const Flow = memo(() => {
[dispatch]
);
// const onConnectStart: OnConnectStart = useCallback(
// (event, params) => {
// dispatch(connectionStarted(params));
// },
// [dispatch]
// );
// const onConnect: OnConnect = useCallback(
// (connection) => {
// dispatch(connectionMade(connection));
// },
// [dispatch]
// );
// const onConnectEnd: OnConnectEnd = useCallback(() => {
// const cursorPosition = $cursorPos.get();
// if (!cursorPosition) {
// return;
// }
// dispatch(
// connectionEnded({
// cursorPosition,
// mouseOverNodeId: $mouseOverNode.get(),
// })
// );
// }, [dispatch]);
const pendingConnection = useStore($pendingConnection);
console.log(pendingConnection)
const onEdgesDelete: OnEdgesDelete = useCallback(
(edges) => {
dispatch(edgesDeleted(edges));

View File

@ -1,7 +1,7 @@
import { useStore } from '@nanostores/react';
import { useAppStore } from 'app/store/storeHooks';
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
import { $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
@ -59,8 +59,11 @@ export const useConnection = () => {
if (connection) {
dispatch(connectionMade(connection));
}
$pendingConnection.set(null);
} else {
// The mouse is not over a node - we should open the add node popover
$isAddNodePopoverOpen.set(true);
}
$pendingConnection.set(null);
}, [store, templates]);
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);

View File

@ -760,6 +760,14 @@ export const $copiedNodes = atom<AnyNode[]>([]);
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
export const $pendingConnection = atom<PendingConnection | null>(null);
export const $isModifyingEdge = atom(false);
export const $isAddNodePopoverOpen = atom(false);
export const closeAddNodePopover = () => {
$isAddNodePopoverOpen.set(false);
$pendingConnection.set(null);
};
export const openAddNodePopover = () => {
$isAddNodePopoverOpen.set(true);
};
export const selectNodesSlice = (state: RootState) => state.nodes.present;