mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(ui): recreate edge auto-add-node logic
This commit is contained in:
parent
2c1fa30639
commit
4d68cd8dbb
@ -4,11 +4,22 @@ import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
|
|||||||
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library';
|
||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppToaster } from 'app/components/Toaster';
|
import { useAppToaster } from 'app/components/Toaster';
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppStore } from 'app/store/storeHooks';
|
||||||
import type { SelectInstance } from 'chakra-react-select';
|
import type { SelectInstance } from 'chakra-react-select';
|
||||||
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
import { useBuildNode } from 'features/nodes/hooks/useBuildNode';
|
||||||
import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice';
|
import {
|
||||||
|
$isAddNodePopoverOpen,
|
||||||
|
$pendingConnection,
|
||||||
|
$templates,
|
||||||
|
closeAddNodePopover,
|
||||||
|
connectionMade,
|
||||||
|
nodeAdded,
|
||||||
|
openAddNodePopover,
|
||||||
|
} from 'features/nodes/store/nodesSlice';
|
||||||
|
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||||
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes';
|
||||||
|
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||||
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { filter, map, memoize, some } from 'lodash-es';
|
import { filter, map, memoize, some } from 'lodash-es';
|
||||||
import type { KeyboardEventHandler } from 'react';
|
import type { KeyboardEventHandler } from 'react';
|
||||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||||
@ -17,6 +28,7 @@ import { useHotkeys } from 'react-hotkeys-hook';
|
|||||||
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
import type { FilterOptionOption } from 'react-select/dist/declarations/src/filters';
|
||||||
|
import { assert } from 'tsafe';
|
||||||
|
|
||||||
const createRegex = memoize(
|
const createRegex = memoize(
|
||||||
(inputValue: string) =>
|
(inputValue: string) =>
|
||||||
@ -50,26 +62,29 @@ const AddNodePopover = () => {
|
|||||||
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
const selectRef = useRef<SelectInstance<ComboboxOption> | null>(null);
|
||||||
const inputRef = useRef<HTMLInputElement>(null);
|
const inputRef = useRef<HTMLInputElement>(null);
|
||||||
const templates = useStore($templates);
|
const templates = useStore($templates);
|
||||||
|
const pendingConnection = useStore($pendingConnection);
|
||||||
|
const isOpen = useStore($isAddNodePopoverOpen);
|
||||||
|
const store = useAppStore();
|
||||||
|
|
||||||
const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType);
|
const filteredTemplates = useMemo(() => {
|
||||||
const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType);
|
// If we have a connection in progress, we need to filter the node choices
|
||||||
|
if (!pendingConnection) {
|
||||||
|
return map(templates);
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter(templates, (template) => {
|
||||||
|
const pendingFieldKind = pendingConnection.fieldTemplate.fieldKind;
|
||||||
|
const fields = pendingFieldKind === 'input' ? template.outputs : template.inputs;
|
||||||
|
return some(fields, (field) => {
|
||||||
|
const sourceType = pendingFieldKind === 'input' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
|
const targetType = pendingFieldKind === 'output' ? field.type : pendingConnection.fieldTemplate.type;
|
||||||
|
return validateSourceAndTargetTypes(sourceType, targetType);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}, [templates, pendingConnection]);
|
||||||
|
|
||||||
const options = useMemo(() => {
|
const options = useMemo(() => {
|
||||||
// If we have a connection in progress, we need to filter the node choices
|
const _options: ComboboxOption[] = map(filteredTemplates, (template) => {
|
||||||
const filteredNodeTemplates = fieldFilter
|
|
||||||
? filter(templates, (template) => {
|
|
||||||
const handles = handleFilter === 'source' ? template.inputs : template.outputs;
|
|
||||||
|
|
||||||
return some(handles, (handle) => {
|
|
||||||
const sourceType = handleFilter === 'source' ? fieldFilter : handle.type;
|
|
||||||
const targetType = handleFilter === 'target' ? fieldFilter : handle.type;
|
|
||||||
|
|
||||||
return validateSourceAndTargetTypes(sourceType, targetType);
|
|
||||||
});
|
|
||||||
})
|
|
||||||
: map(templates);
|
|
||||||
|
|
||||||
const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => {
|
|
||||||
return {
|
return {
|
||||||
label: template.title,
|
label: template.title,
|
||||||
value: template.type,
|
value: template.type,
|
||||||
@ -79,15 +94,15 @@ const AddNodePopover = () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
//We only want these nodes if we're not filtered
|
//We only want these nodes if we're not filtered
|
||||||
if (fieldFilter === null) {
|
if (!pendingConnection) {
|
||||||
options.push({
|
_options.push({
|
||||||
label: t('nodes.currentImage'),
|
label: t('nodes.currentImage'),
|
||||||
value: 'current_image',
|
value: 'current_image',
|
||||||
description: t('nodes.currentImageDescription'),
|
description: t('nodes.currentImageDescription'),
|
||||||
tags: ['progress'],
|
tags: ['progress'],
|
||||||
});
|
});
|
||||||
|
|
||||||
options.push({
|
_options.push({
|
||||||
label: t('nodes.notes'),
|
label: t('nodes.notes'),
|
||||||
value: 'notes',
|
value: 'notes',
|
||||||
description: t('nodes.notesDescription'),
|
description: t('nodes.notesDescription'),
|
||||||
@ -95,15 +110,13 @@ const AddNodePopover = () => {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
options.sort((a, b) => a.label.localeCompare(b.label));
|
_options.sort((a, b) => a.label.localeCompare(b.label));
|
||||||
|
|
||||||
return options;
|
return _options;
|
||||||
}, [fieldFilter, handleFilter, t, templates]);
|
}, [filteredTemplates, pendingConnection, t]);
|
||||||
|
|
||||||
const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen);
|
|
||||||
|
|
||||||
const addNode = useCallback(
|
const addNode = useCallback(
|
||||||
(nodeType: string) => {
|
(nodeType: string): AnyNode | null => {
|
||||||
const invocation = buildInvocation(nodeType);
|
const invocation = buildInvocation(nodeType);
|
||||||
if (!invocation) {
|
if (!invocation) {
|
||||||
const errorMessage = t('nodes.unknownNode', {
|
const errorMessage = t('nodes.unknownNode', {
|
||||||
@ -113,10 +126,11 @@ const AddNodePopover = () => {
|
|||||||
status: 'error',
|
status: 'error',
|
||||||
title: errorMessage,
|
title: errorMessage,
|
||||||
});
|
});
|
||||||
return;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
dispatch(nodeAdded(invocation));
|
dispatch(nodeAdded(invocation));
|
||||||
|
return invocation;
|
||||||
},
|
},
|
||||||
[dispatch, buildInvocation, toaster, t]
|
[dispatch, buildInvocation, toaster, t]
|
||||||
);
|
);
|
||||||
@ -126,52 +140,50 @@ const AddNodePopover = () => {
|
|||||||
if (!v) {
|
if (!v) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
addNode(v.value);
|
const node = addNode(v.value);
|
||||||
dispatch(addNodePopoverClosed());
|
|
||||||
|
// Auto-connect an edge if we just added a node and have a pending connection
|
||||||
|
if (pendingConnection && isInvocationNode(node)) {
|
||||||
|
const template = templates[node.data.type];
|
||||||
|
assert(template, 'Template not found');
|
||||||
|
const { nodes, edges } = store.getState().nodes.present;
|
||||||
|
const connection = getFirstValidConnection(nodes, edges, pendingConnection, node, template);
|
||||||
|
if (connection) {
|
||||||
|
dispatch(connectionMade(connection));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
closeAddNodePopover();
|
||||||
},
|
},
|
||||||
[addNode, dispatch]
|
[addNode, dispatch, pendingConnection, store, templates]
|
||||||
);
|
);
|
||||||
|
|
||||||
const onClose = useCallback(() => {
|
const handleHotkeyOpen: HotkeyCallback = useCallback((e) => {
|
||||||
dispatch(addNodePopoverClosed());
|
e.preventDefault();
|
||||||
}, [dispatch]);
|
openAddNodePopover();
|
||||||
|
flushSync(() => {
|
||||||
const onOpen = useCallback(() => {
|
selectRef.current?.inputRef?.focus();
|
||||||
dispatch(addNodePopoverOpened());
|
});
|
||||||
}, [dispatch]);
|
}, []);
|
||||||
|
|
||||||
const handleHotkeyOpen: HotkeyCallback = useCallback(
|
|
||||||
(e) => {
|
|
||||||
e.preventDefault();
|
|
||||||
onOpen();
|
|
||||||
flushSync(() => {
|
|
||||||
selectRef.current?.inputRef?.focus();
|
|
||||||
});
|
|
||||||
},
|
|
||||||
[onOpen]
|
|
||||||
);
|
|
||||||
|
|
||||||
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
|
const handleHotkeyClose: HotkeyCallback = useCallback(() => {
|
||||||
onClose();
|
closeAddNodePopover();
|
||||||
}, [onClose]);
|
}, []);
|
||||||
|
|
||||||
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
|
useHotkeys(['shift+a', 'space'], handleHotkeyOpen);
|
||||||
useHotkeys(['escape'], handleHotkeyClose);
|
useHotkeys(['escape'], handleHotkeyClose);
|
||||||
const onKeyDown: KeyboardEventHandler = useCallback(
|
const onKeyDown: KeyboardEventHandler = useCallback((e) => {
|
||||||
(e) => {
|
if (e.key === 'Escape') {
|
||||||
if (e.key === 'Escape') {
|
closeAddNodePopover();
|
||||||
onClose();
|
}
|
||||||
}
|
}, []);
|
||||||
},
|
|
||||||
[onClose]
|
|
||||||
);
|
|
||||||
|
|
||||||
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
const noOptionsMessage = useCallback(() => t('nodes.noMatchingNodes'), [t]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<Popover
|
<Popover
|
||||||
isOpen={isOpen}
|
isOpen={isOpen}
|
||||||
onClose={onClose}
|
onClose={closeAddNodePopover}
|
||||||
placement="bottom"
|
placement="bottom"
|
||||||
openDelay={0}
|
openDelay={0}
|
||||||
closeDelay={0}
|
closeDelay={0}
|
||||||
@ -201,7 +213,7 @@ const AddNodePopover = () => {
|
|||||||
noOptionsMessage={noOptionsMessage}
|
noOptionsMessage={noOptionsMessage}
|
||||||
filterOption={filterOption}
|
filterOption={filterOption}
|
||||||
onChange={onChange}
|
onChange={onChange}
|
||||||
onMenuClose={onClose}
|
onMenuClose={closeAddNodePopover}
|
||||||
onKeyDown={onKeyDown}
|
onKeyDown={onKeyDown}
|
||||||
inputRef={inputRef}
|
inputRef={inputRef}
|
||||||
closeMenuOnSelect={false}
|
closeMenuOnSelect={false}
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||||
import { useStore } from '@nanostores/react';
|
|
||||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||||
@ -7,7 +6,6 @@ import { useIsValidConnection } from 'features/nodes/hooks/useIsValidConnection'
|
|||||||
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
import { useWorkflowWatcher } from 'features/nodes/hooks/useWorkflowWatcher';
|
||||||
import {
|
import {
|
||||||
$cursorPos,
|
$cursorPos,
|
||||||
$pendingConnection,
|
|
||||||
connectionMade,
|
connectionMade,
|
||||||
edgeAdded,
|
edgeAdded,
|
||||||
edgeChangeStarted,
|
edgeChangeStarted,
|
||||||
@ -100,36 +98,6 @@ export const Flow = memo(() => {
|
|||||||
[dispatch]
|
[dispatch]
|
||||||
);
|
);
|
||||||
|
|
||||||
// const onConnectStart: OnConnectStart = useCallback(
|
|
||||||
// (event, params) => {
|
|
||||||
// dispatch(connectionStarted(params));
|
|
||||||
// },
|
|
||||||
// [dispatch]
|
|
||||||
// );
|
|
||||||
|
|
||||||
// const onConnect: OnConnect = useCallback(
|
|
||||||
// (connection) => {
|
|
||||||
// dispatch(connectionMade(connection));
|
|
||||||
// },
|
|
||||||
// [dispatch]
|
|
||||||
// );
|
|
||||||
|
|
||||||
// const onConnectEnd: OnConnectEnd = useCallback(() => {
|
|
||||||
// const cursorPosition = $cursorPos.get();
|
|
||||||
// if (!cursorPosition) {
|
|
||||||
// return;
|
|
||||||
// }
|
|
||||||
// dispatch(
|
|
||||||
// connectionEnded({
|
|
||||||
// cursorPosition,
|
|
||||||
// mouseOverNodeId: $mouseOverNode.get(),
|
|
||||||
// })
|
|
||||||
// );
|
|
||||||
// }, [dispatch]);
|
|
||||||
|
|
||||||
const pendingConnection = useStore($pendingConnection);
|
|
||||||
console.log(pendingConnection)
|
|
||||||
|
|
||||||
const onEdgesDelete: OnEdgesDelete = useCallback(
|
const onEdgesDelete: OnEdgesDelete = useCallback(
|
||||||
(edges) => {
|
(edges) => {
|
||||||
dispatch(edgesDeleted(edges));
|
dispatch(edgesDeleted(edges));
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { useStore } from '@nanostores/react';
|
import { useStore } from '@nanostores/react';
|
||||||
import { useAppStore } from 'app/store/storeHooks';
|
import { useAppStore } from 'app/store/storeHooks';
|
||||||
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
import { $mouseOverNode } from 'features/nodes/hooks/useMouseOverNode';
|
||||||
import { $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
import { $isAddNodePopoverOpen, $pendingConnection, $templates, connectionMade } from 'features/nodes/store/nodesSlice';
|
||||||
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
import { getFirstValidConnection } from 'features/nodes/store/util/findConnectionToValidHandle';
|
||||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||||
import { useCallback, useMemo } from 'react';
|
import { useCallback, useMemo } from 'react';
|
||||||
@ -59,8 +59,11 @@ export const useConnection = () => {
|
|||||||
if (connection) {
|
if (connection) {
|
||||||
dispatch(connectionMade(connection));
|
dispatch(connectionMade(connection));
|
||||||
}
|
}
|
||||||
|
$pendingConnection.set(null);
|
||||||
|
} else {
|
||||||
|
// The mouse is not over a node - we should open the add node popover
|
||||||
|
$isAddNodePopoverOpen.set(true);
|
||||||
}
|
}
|
||||||
$pendingConnection.set(null);
|
|
||||||
}, [store, templates]);
|
}, [store, templates]);
|
||||||
|
|
||||||
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
const api = useMemo(() => ({ onConnectStart, onConnect, onConnectEnd }), [onConnectStart, onConnect, onConnectEnd]);
|
||||||
|
@ -760,6 +760,14 @@ export const $copiedNodes = atom<AnyNode[]>([]);
|
|||||||
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
export const $copiedEdges = atom<InvocationNodeEdge[]>([]);
|
||||||
export const $pendingConnection = atom<PendingConnection | null>(null);
|
export const $pendingConnection = atom<PendingConnection | null>(null);
|
||||||
export const $isModifyingEdge = atom(false);
|
export const $isModifyingEdge = atom(false);
|
||||||
|
export const $isAddNodePopoverOpen = atom(false);
|
||||||
|
export const closeAddNodePopover = () => {
|
||||||
|
$isAddNodePopoverOpen.set(false);
|
||||||
|
$pendingConnection.set(null);
|
||||||
|
};
|
||||||
|
export const openAddNodePopover = () => {
|
||||||
|
$isAddNodePopoverOpen.set(true);
|
||||||
|
};
|
||||||
|
|
||||||
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
export const selectNodesSlice = (state: RootState) => state.nodes.present;
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user