mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): recreate edge auto-add-node logic
This commit is contained in:
parent
2c1fa30639
commit
4d68cd8dbb
@ -4,11 +4,22 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
||||
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppToaster } from 'app/components/Toaster';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
|
||||
import type { SelectInstance } from 'chakra-react-select';
|
||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||
import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
$isAddNodePopoverOpen,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
closeAddNodePopover,
|
||||
connectionMade,
|
||||
nodeAdded,
|
||||
openAddNodePopover,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { filter, map, memoize, some } from 'lodash-es';
|
||||
import type { KeyboardEventHandler } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
@ -17,6 +28,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const createRegex = memoize(
|
||||
(inputValue: string) =>
|
||||
@ -50,26 +62,29 @@ const AddNodePopover = () => {
|
||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const templates = useStore($templates);
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const isOpen = useStore($isAddNodePopoverOpen);
|
||||
const store = useAppStore();
|
||||
|
||||
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
|
||||
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
|
||||
const filteredTemplates = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
if (!pendingConnection) {
|
||||
return map(templates);
|
||||
}
|
||||
|
||||
return filter(templates, (template) => {
|
||||
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
||||
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
||||
return some(fields, (field) => {
|
||||
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
});
|
||||
}, [templates, pendingConnection]);
|
||||
|
||||
const options = useMemo(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
const filteredNodeTemplates = fieldFilter
|
||||
? filter(templates, (template) => {
|
||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
||||
|
||||
return some(handles, (handle) => {
|
||||
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
|
||||
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
|
||||
|
||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||
});
|
||||
})
|
||||
: map(templates);
|
||||
|
||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
||||
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
|
||||
return {
|
||||
label: template.title,
|
||||
value: template.type,
|
||||
@ -79,15 +94,15 @@ const AddNodePopover = () => {
|
||||
});
|
||||
|
||||
//We only want these nodes if we're not filtered
|
||||
if (fieldFilter === null) {
|
||||
options.push({
|
||||
if (!pendingConnection) {
|
||||
_options.push({
|
||||
label: t('nodes.currentImage'),
|
||||
value: 'current_image',
|
||||
description: t('nodes.currentImageDescription'),
|
||||
tags: ['progress'],
|
||||
});
|
||||
|
||||
options.push({
|
||||
_options.push({
|
||||
label: t('nodes.notes'),
|
||||
value: 'notes',
|
||||
description: t('nodes.notesDescription'),
|
||||
@ -95,15 +110,13 @@ const AddNodePopover = () => {
|
||||
});
|
||||
}
|
||||
|
||||
options.sort((a, b) => a.label.localeCompare(b.label));
|
||||
_options.sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
return options;
|
||||
}, [fieldFilter, handleFilter, t, templates]);
|
||||
|
||||
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
|
||||
return _options;
|
||||
}, [filteredTemplates, pendingConnection, t]);
|
||||
|
||||
const addNode = useCallback(
|
||||
(nodeType: string) => {
|
||||
(nodeType: string): AnyNode | null => {
|
||||
const invocation = buildInvocation(nodeType);
|
||||
if (!invocation) {
|
||||
const errorMessage = t('nodes.unknownNode', {
|
||||
@ -113,10 +126,11 @@ const AddNodePopover = () => {
|
||||
status: 'error',
|
||||
title: errorMessage,
|
||||
});
|
||||
return;
|
||||
return null;
|
||||
}
|
||||
|
||||
dispatch(nodeAdded(invocation));
|
||||
return invocation;
|
||||
},
|
||||
[dispatch, buildInvocation, toaster, t]
|
||||
);
|
||||
@ -126,52 +140,50 @@ const AddNodePopover = () => {
|
||||
if (!v) {
|
||||
return;
|
||||
}
|
||||
addNode(v.value);
|
||||
dispatch(addNodePopoverClosed());
|
||||
const node = addNode(v.value);
|
||||
|
||||
// Auto-connect an edge if we just added a node and have a pending connection
|
||||
if (pendingConnection && isInvocationNode(node)) {
|
||||
const template = templates[node.data.type];
|
||||
assert(template, 'Template not found');
|
||||
const { nodes, edges } = store.getState().nodes.present;
|
||||
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
}
|
||||
|
||||
closeAddNodePopover();
|
||||
},
|
||||
[addNode, dispatch]
|
||||
[addNode, dispatch, pendingConnection, store, templates]
|
||||
);
|
||||
|
||||
const onClose = useCallback(() => {
|
||||
dispatch(addNodePopoverClosed());
|
||||
}, [dispatch]);
|
||||
|
||||
const onOpen = useCallback(() => {
|
||||
dispatch(addNodePopoverOpened());
|
||||
}, [dispatch]);
|
||||
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback(
|
||||
(e) => {
|
||||
e.preventDefault();
|
||||
onOpen();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
},
|
||||
[onOpen]
|
||||
);
|
||||
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
|
||||
e.preventDefault();
|
||||
openAddNodePopover();
|
||||
flushSync(() => {
|
||||
selectRef.current?.inputRef?.focus();
|
||||
});
|
||||
}, []);
|
||||
|
||||
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
|
||||
onClose();
|
||||
}, [onClose]);
|
||||
closeAddNodePopover();
|
||||
}, []);
|
||||
|
||||
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
|
||||
useHotkeys(['escape'], handleHotkeyClose);
|
||||
const onKeyDown: KeyboardEventHandler = useCallback(
|
||||
(e) => {
|
||||
if (e.key === 'Escape') {
|
||||
onClose();
|
||||
}
|
||||
},
|
||||
[onClose]
|
||||
);
|
||||
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
|
||||
if (e.key === 'Escape') {
|
||||
closeAddNodePopover();
|
||||
}
|
||||
}, []);
|
||||
|
||||
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
||||
|
||||
return (
|
||||
<Popover
|
||||
isOpen={isOpen}
|
||||
onClose={onClose}
|
||||
onClose={closeAddNodePopover}
|
||||
placement="bottom"
|
||||
openDelay={0}
|
||||
closeDelay={0}
|
||||
@ -201,7 +213,7 @@ const AddNodePopover = () => {
|
||||
noOptionsMessage={noOptionsMessage}
|
||||
filterOption={filterOption}
|
||||
onChange={onChange}
|
||||
onMenuClose={onClose}
|
||||
onMenuClose={closeAddNodePopover}
|
||||
onKeyDown={onKeyDown}
|
||||
inputRef={inputRef}
|
||||
closeMenuOnSelect={false}
|
||||
|
@ -1,5 +1,4 @@
|
||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
@ -7,7 +6,6 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
|
||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||
import {
|
||||
$cursorPos,
|
||||
$pendingConnection,
|
||||
connectionMade,
|
||||
edgeAdded,
|
||||
edgeChangeStarted,
|
||||
@ -100,36 +98,6 @@ export const Flow = memo(() => {
|
||||
[dispatch]
|
||||
);
|
||||
|
||||
// const onConnectStart: OnConnectStart = useCallback(
|
||||
// (event, params) => {
|
||||
// dispatch(connectionStarted(params));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
// const onConnect: OnConnect = useCallback(
|
||||
// (connection) => {
|
||||
// dispatch(connectionMade(connection));
|
||||
// },
|
||||
// [dispatch]
|
||||
// );
|
||||
|
||||
// const onConnectEnd: OnConnectEnd = useCallback(() => {
|
||||
// const cursorPosition = $cursorPos.get();
|
||||
// if (!cursorPosition) {
|
||||
// return;
|
||||
// }
|
||||
// dispatch(
|
||||
// connectionEnded({
|
||||
// cursorPosition,
|
||||
// mouseOverNodeId: $mouseOverNode.get(),
|
||||
// })
|
||||
// );
|
||||
// }, [dispatch]);
|
||||
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
console.log(pendingConnection)
|
||||
|
||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||
(edges) => {
|
||||
dispatch(edgesDeleted(edges));
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
||||
import { $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
@ -59,8 +59,11 @@ export const useConnection = () => {
|
||||
if (connection) {
|
||||
dispatch(connectionMade(connection));
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
} else {
|
||||
// The mouse is not over a node - we should open the add node popover
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
}
|
||||
$pendingConnection.set(null);
|
||||
}, [store, templates]);
|
||||
|
||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||
|
@ -760,6 +760,14 @@ export const $copiedNodes = atom<AnyNode[]>([]);
|
||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||
export const $isModifyingEdge = atom(false);
|
||||
export const $isAddNodePopoverOpen = atom(false);
|
||||
export const closeAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(false);
|
||||
$pendingConnection.set(null);
|
||||
};
|
||||
export const openAddNodePopover = () => {
|
||||
$isAddNodePopoverOpen.set(true);
|
||||
};
|
||||
|
||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user