diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index 585d28007a..4cedfd81e8 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -54,6 +54,7 @@ "dateformat": "^5.0.3", "formik": "^2.2.9", "framer-motion": "^9.0.4", + "fuse.js": "^6.6.2", "i18next": "^22.4.10", "i18next-browser-languagedetector": "^7.0.1", "i18next-http-backend": "^2.1.1", diff --git a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx index 9a05a6d2a0..a24688eb8a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/search/NodeSearch.tsx @@ -5,6 +5,7 @@ import IAIInput from 'common/components/IAIInput'; import { Panel } from 'reactflow'; import { map } from 'lodash'; import { + ChangeEvent, FocusEvent, KeyboardEvent, memo, @@ -19,6 +20,8 @@ import { useBuildInvocation } from 'features/nodes/hooks/useBuildInvocation'; import { makeToast } from 'features/system/hooks/useToastWatcher'; import { addToast } from 'features/system/store/systemSlice'; import { nodeAdded } from '../../store/nodesSlice'; +import Fuse from 'fuse.js'; +import { InvocationTemplate } from 'features/nodes/types/types'; interface NodeListItemProps { title: string; @@ -55,6 +58,9 @@ const NodeSearch = () => { ); const nodes = map(invocationTemplates); + const [filteredNodes, setFilteredNodes] = useState< + Fuse.FuseResult[] + >([]); const buildInvocation = useBuildInvocation(); const dispatch = useAppDispatch(); @@ -64,6 +70,21 @@ const NodeSearch = () => { const [focusedIndex, setFocusedIndex] = useState(-1); const nodeSearchRef = useRef(null); + const fuseOptions = { + findAllMatches: true, + threshold: 0, + ignoreLocation: true, + keys: ['title', 'type', 'tags'], + }; + + const fuse = new Fuse(nodes, fuseOptions); + + const findNode = (e: ChangeEvent) => { + setSearchText(e.target.value); + setFilteredNodes(fuse.search(e.target.value)); + setShowNodeList(true); + }; + const addNode = useCallback( (nodeType: AnyInvocationType) => { const invocation = buildInvocation(nodeType); @@ -85,8 +106,24 @@ const NodeSearch = () => { const renderNodeList = () => { const nodeListToRender: ReactNode[] = []; - nodes.forEach(({ title, description, type }, index) => { - if (title.toLowerCase().includes(searchText)) { + if (searchText.length > 0) { + filteredNodes.forEach(({ item }, index) => { + const { title, description, type } = item; + if (title.toLowerCase().includes(searchText)) { + nodeListToRender.push( + + ); + } + }); + } else { + nodes.forEach(({ title, description, type }, index) => { nodeListToRender.push( { addNode={addNode} /> ); - } else { - ; - } - }); + }); + } return ( - + {nodeListToRender} ); @@ -128,12 +150,21 @@ const NodeSearch = () => { if (key === 'ArrowDown') { setShowNodeList(true); - nextIndex = (focusedIndex + 1) % nodes.length; + if (searchText.length > 0) { + nextIndex = (focusedIndex + 1) % filteredNodes.length; + } else { + nextIndex = (focusedIndex + 1) % nodes.length; + } } if (key === 'ArrowUp') { setShowNodeList(true); - nextIndex = (focusedIndex + nodes.length - 1) % nodes.length; + if (searchText.length > 0) { + nextIndex = + (focusedIndex + filteredNodes.length - 1) % filteredNodes.length; + } else { + nextIndex = (focusedIndex + filteredNodes.length - 1) % nodes.length; + } } // # TODO Handle Blur @@ -141,7 +172,14 @@ const NodeSearch = () => { // } if (key === 'Enter') { - const selectedNodeType = nodes[focusedIndex].type; + let selectedNodeType: AnyInvocationType; + + if (searchText.length > 0) { + selectedNodeType = filteredNodes[focusedIndex].item.type; + } else { + selectedNodeType = nodes[focusedIndex].type; + } + addNode(selectedNodeType); setShowNodeList(false); } @@ -163,13 +201,7 @@ const NodeSearch = () => { onBlur={searchInputBlurHandler} ref={nodeSearchRef} > - { - setSearchText(e.target.value); - setShowNodeList(true); - }} - /> + {showNodeList && renderNodeList()} diff --git a/invokeai/frontend/web/yarn.lock b/invokeai/frontend/web/yarn.lock index 4c7edbffd7..8688f1fd41 100644 --- a/invokeai/frontend/web/yarn.lock +++ b/invokeai/frontend/web/yarn.lock @@ -3454,6 +3454,11 @@ functions-have-names@^1.2.2: resolved "https://registry.yarnpkg.com/functions-have-names/-/functions-have-names-1.2.3.tgz#0404fe4ee2ba2f607f0e0ec3c80bae994133b834" integrity sha512-xckBUXyTIqT97tq2x2AMb+g163b5JFysYk0x4qxNFwbfQkmNZoiRHb6sPzI9/QV33WeuvVYBUIiD4NzNIyqaRQ== +fuse.js@^6.6.2: + version "6.6.2" + resolved "https://registry.yarnpkg.com/fuse.js/-/fuse.js-6.6.2.tgz#fe463fed4b98c0226ac3da2856a415576dc9a111" + integrity sha512-cJaJkxCCxC8qIIcPBF9yGxY0W/tVZS3uEISDxhYIdtk8OL93pe+6Zj7LjCqVV4dzbqcriOZ+kQ/NE4RXZHsIGA== + get-amd-module-type@^3.0.0: version "3.0.2" resolved "https://registry.yarnpkg.com/get-amd-module-type/-/get-amd-module-type-3.0.2.tgz#46550cee2b8e1fa4c3f2c8a5753c36990aa49ab0"