diff --git a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx index 585bd91954..f7d186d11a 100644 --- a/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/CurrentImage/CurrentImagePreview.tsx @@ -11,13 +11,11 @@ import type { import ProgressImage from 'features/gallery/components/CurrentImage/ProgressImage'; import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer'; import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons'; -import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage'; import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; import type { AnimationProps } from 'framer-motion'; import { AnimatePresence, motion } from 'framer-motion'; import type { CSSProperties } from 'react'; import { memo, useCallback, useMemo, useRef, useState } from 'react'; -import { useHotkeys } from 'react-hotkeys-hook'; import { useTranslation } from 'react-i18next'; import { FaImage } from 'react-icons/fa'; import { useGetImageDTOQuery } from 'services/api/endpoints/images'; @@ -39,61 +37,6 @@ const CurrentImagePreview = () => { (s) => s.ui.shouldShowProgressInViewer ); - const { - handlePrevImage, - handleNextImage, - isOnLastImage, - handleLoadMoreImages, - areMoreImagesAvailable, - isFetching, - handleTopImage, - handleBottomImage - } = useNextPrevImage(); - - useHotkeys( - 'left', - () => { - handlePrevImage(); - }, - [handlePrevImage] - ); - - useHotkeys( - 'right', - () => { - if (isOnLastImage && areMoreImagesAvailable && !isFetching) { - handleLoadMoreImages(); - return; - } - if (!isOnLastImage) { - handleNextImage(); - } - }, - [ - isOnLastImage, - areMoreImagesAvailable, - handleLoadMoreImages, - isFetching, - handleNextImage, - ] - ); - - useHotkeys( - 'up', - () => { - handleTopImage(); - }, - [handleTopImage] - ); - - useHotkeys( - 'down', - () => { - handleBottomImage(); - }, - [handleBottomImage] - ); - const { currentData: imageDTO } = useGetImageDTOQuery(imageName ?? skipToken); const draggableData = useMemo(() => { diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index 95a1b5fb2e..78689795ba 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -13,9 +13,9 @@ import type { ImageDTOsDraggableData, TypesafeDraggableData, } from 'features/dnd/types'; -import type { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types'; +import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId'; import { useMultiselect } from 'features/gallery/hooks/useMultiselect'; -import { useScrollToVisible } from 'features/gallery/hooks/useScrollToVisible'; +import { useScrollIntoView } from 'features/gallery/hooks/useScrollIntoView'; import type { MouseEvent } from 'react'; import { memo, useCallback, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -35,12 +35,11 @@ const imageIconStyleOverrides: SystemStyleObject = { interface HoverableImageProps { imageName: string; index: number; - virtuosoContext: VirtuosoGalleryContext; } const GalleryImage = (props: HoverableImageProps) => { const dispatch = useAppDispatch(); - const { imageName, virtuosoContext } = props; + const { imageName } = props; const { currentData: imageDTO } = useGetImageDTOQuery(imageName); const shift = useStore($shift); const { t } = useTranslation(); @@ -50,11 +49,10 @@ const GalleryImage = (props: HoverableImageProps) => { const customStarUi = useStore($customStarUI); - const imageContainerRef = useScrollToVisible( + const imageContainerRef = useScrollIntoView( isSelected, props.index, - selectionCount, - virtuosoContext + selectionCount ); const handleDelete = useCallback( @@ -131,12 +129,22 @@ const GalleryImage = (props: HoverableImageProps) => { return ''; }, [imageDTO?.starred, customStarUi]); + const dataTestId = useMemo( + () => getGalleryImageDataTestId(imageDTO?.image_name), + [imageDTO?.image_name] + ); + if (!imageDTO) { return ; } return ( - + { ); const selectedBoardId = useAppSelector((s) => s.gallery.selectedBoardId); const { currentViewTotal } = useBoardTotal(selectedBoardId); - const queryArgs = useAppSelector(selectListImagesBaseQueryArgs); - const virtuosoRangeRef = useRef(null); - const virtuosoRef = useRef(null); - - const { currentData, isFetching, isSuccess, isError } = - useListImagesQuery(queryArgs); - - const [listImages] = useLazyListImagesQuery(); - - const areMoreAvailable = useMemo(() => { - if (!currentData || !currentViewTotal) { - return false; - } - return currentData.ids.length < currentViewTotal; - }, [currentData, currentViewTotal]); - - const handleLoadMoreImages = useCallback(() => { - if (!areMoreAvailable) { - return; - } - - listImages({ - ...queryArgs, - offset: currentData?.ids.length ?? 0, - limit: IMAGE_LIMIT, - }); - }, [areMoreAvailable, listImages, queryArgs, currentData?.ids.length]); - - const virtuosoContext = useMemo(() => { - return { - virtuosoRef, - rootRef, - virtuosoRangeRef, - }; - }, []); - - const itemContentFunc: ItemContent = - useCallback( - (index, imageName, virtuosoContext) => ( - - ), - [] - ); + const { + areMoreImagesAvailable, + handleLoadMoreImages, + queryResult: { currentData, isFetching, isSuccess, isError }, + } = useGalleryImages(); + useGalleryHotkeys(); + const itemContentFunc: ItemContent = useCallback( + (index, imageName) => ( + + ), + [] + ); useEffect(() => { // Initialize the gallery's custom scrollbar @@ -116,8 +79,10 @@ const GalleryImageGrid = () => { }, []); useEffect(() => { - $useNextPrevImageState.setKey('virtuosoRef', virtuosoRef); - $useNextPrevImageState.setKey('virtuosoRangeRef', virtuosoRangeRef); + virtuosoGridRefs.set({ rootRef, virtuosoRangeRef, virtuosoRef }); + return () => { + virtuosoGridRefs.set({}); + }; }, []); if (!currentData) { @@ -142,7 +107,7 @@ const GalleryImageGrid = () => { if (isSuccess && currentData) { return ( <> - + { itemContent={itemContentFunc} ref={virtuosoRef} rangeChanged={onRangeChanged} - context={virtuosoContext} overscan={10} /> ( {props.children} diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/getGalleryImageDataTestId.ts b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/getGalleryImageDataTestId.ts new file mode 100644 index 0000000000..ce3ab4ae46 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/getGalleryImageDataTestId.ts @@ -0,0 +1,2 @@ +export const getGalleryImageDataTestId = (imageName?: string) => + `gallery-image-${imageName}`; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/types.ts b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/types.ts index 1e5ac93bdb..e43a55270e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/types.ts +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/types.ts @@ -1,8 +1,11 @@ +import { atom } from 'nanostores'; import type { RefObject } from 'react'; import type { ListRange, VirtuosoGridHandle } from 'react-virtuoso'; -export type VirtuosoGalleryContext = { - virtuosoRef: RefObject; - rootRef: RefObject; - virtuosoRangeRef: RefObject; +export type VirtuosoGridRefs = { + virtuosoRef?: RefObject; + rootRef?: RefObject; + virtuosoRangeRef?: RefObject; }; + +export const virtuosoGridRefs = atom({}); diff --git a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx index 70514faad0..64d48e748b 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NextPrevImageButtons.tsx @@ -1,7 +1,8 @@ import type { ChakraProps } from '@chakra-ui/react'; import { Box, Flex, Spinner } from '@chakra-ui/react'; import { InvIconButton } from 'common/components/InvIconButton/InvIconButton'; -import { useNextPrevImage } from 'features/gallery/hooks/useNextPrevImage'; +import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages'; +import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; import { FaAngleDoubleRight, FaAngleLeft, FaAngleRight } from 'react-icons/fa'; @@ -14,15 +15,14 @@ const nextPrevButtonStyles: ChakraProps['sx'] = { const NextPrevImageButtons = () => { const { t } = useTranslation(); + const { handleLeftImage, handleRightImage, isOnFirstImage, isOnLastImage } = + useGalleryNavigation(); + const { - handlePrevImage, - handleNextImage, - isOnFirstImage, - isOnLastImage, - handleLoadMoreImages, areMoreImagesAvailable, - isFetching, - } = useNextPrevImage(); + handleLoadMoreImages, + queryResult: { isFetching }, + } = useGalleryImages(); return ( @@ -37,7 +37,7 @@ const NextPrevImageButtons = () => { aria-label={t('accessibility.previousImage')} icon={} variant="unstyled" - onClick={handlePrevImage} + onClick={handleLeftImage} boxSize={16} sx={nextPrevButtonStyles} /> @@ -54,7 +54,7 @@ const NextPrevImageButtons = () => { aria-label={t('accessibility.nextImage')} icon={} variant="unstyled" - onClick={handleNextImage} + onClick={handleRightImage} boxSize={16} sx={nextPrevButtonStyles} /> diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGalleryHotkeys.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryHotkeys.ts new file mode 100644 index 0000000000..e027900b93 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryHotkeys.ts @@ -0,0 +1,79 @@ +import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages'; +import { useGalleryNavigation } from 'features/gallery/hooks/useGalleryNavigation'; +import { useHotkeys } from 'react-hotkeys-hook'; + +/** + * Registers gallery hotkeys. This hook is a singleton. + */ +export const useGalleryHotkeys = () => { + const { + areMoreImagesAvailable, + handleLoadMoreImages, + queryResult: { isFetching }, + } = useGalleryImages(); + + const { + handleLeftImage, + handleRightImage, + handleUpImage, + handleDownImage, + isOnLastImage, + areImagesBelowCurrent, + } = useGalleryNavigation(); + + useHotkeys( + 'left', + () => { + handleLeftImage(); + }, + [handleLeftImage] + ); + + useHotkeys( + 'right', + () => { + if (isOnLastImage && areMoreImagesAvailable && !isFetching) { + handleLoadMoreImages(); + return; + } + if (!isOnLastImage) { + handleRightImage(); + } + }, + [ + isOnLastImage, + areMoreImagesAvailable, + handleLoadMoreImages, + isFetching, + handleRightImage, + ] + ); + + useHotkeys( + 'up', + () => { + handleUpImage(); + }, + { preventDefault: true }, + [handleUpImage] + ); + + useHotkeys( + 'down', + () => { + if (!areImagesBelowCurrent && areMoreImagesAvailable && !isFetching) { + handleLoadMoreImages(); + return; + } + handleDownImage(); + }, + { preventDefault: true }, + [ + areImagesBelowCurrent, + areMoreImagesAvailable, + handleLoadMoreImages, + isFetching, + handleDownImage, + ] + ); +}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGalleryImages.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryImages.ts new file mode 100644 index 0000000000..471f0846f7 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryImages.ts @@ -0,0 +1,73 @@ +import { useStore } from '@nanostores/react'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { IMAGE_LIMIT } from 'features/gallery/store/types'; +import { atom } from 'nanostores'; +import { useCallback, useMemo } from 'react'; +import { + useGetBoardAssetsTotalQuery, + useGetBoardImagesTotalQuery, +} from 'services/api/endpoints/boards'; +import { useListImagesQuery } from 'services/api/endpoints/images'; +import type { ListImagesArgs } from 'services/api/types'; + +export type UseGalleryImagesReturn = { + handleLoadMoreImages: () => void; + areMoreImagesAvailable: boolean; + queryResult: ReturnType; +}; + +// The gallery is a singleton but multiple components need access to its query data. +// If we don't define the query args outside of the hook, then each component will +// have its own query args and trigger multiple requests. We use an atom to store +// the query args outside of the hook so that all consumers use the same query args. +const $queryArgs = atom(null); + +/** + * Provides access to the gallery images and a way to imperatively fetch more. + * + * This hook is a singleton. + */ +export const useGalleryImages = (): UseGalleryImagesReturn => { + const galleryView = useAppSelector((s) => s.gallery.galleryView); + const baseQueryArgs = useAppSelector(selectListImagesBaseQueryArgs); + const queryArgs = useStore($queryArgs); + const queryResult = useListImagesQuery(queryArgs ?? baseQueryArgs); + const boardId = useMemo( + () => baseQueryArgs.board_id ?? 'none', + [baseQueryArgs.board_id] + ); + const { data: assetsTotal } = useGetBoardAssetsTotalQuery(boardId); + const { data: imagesTotal } = useGetBoardImagesTotalQuery(boardId); + const currentViewTotal = useMemo( + () => (galleryView === 'images' ? imagesTotal?.total : assetsTotal?.total), + [assetsTotal?.total, galleryView, imagesTotal?.total] + ); + const loadedImagesCount = useMemo( + () => queryResult.data?.ids.length ?? 0, + [queryResult.data?.ids.length] + ); + const areMoreImagesAvailable = useMemo(() => { + if (!currentViewTotal || !queryResult.data) { + return false; + } + return queryResult.data.ids.length < currentViewTotal; + }, [queryResult.data, currentViewTotal]); + const handleLoadMoreImages = useCallback(() => { + // To load more images, we update the query args with an offset and limit. + const _queryArgs: ListImagesArgs = loadedImagesCount + ? { + ...baseQueryArgs, + offset: loadedImagesCount, + limit: IMAGE_LIMIT, + } + : baseQueryArgs; + $queryArgs.set(_queryArgs); + }, [baseQueryArgs, loadedImagesCount]); + + return { + areMoreImagesAvailable, + handleLoadMoreImages, + queryResult, + }; +}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useGalleryNavigation.ts b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryNavigation.ts new file mode 100644 index 0000000000..4744de8069 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useGalleryNavigation.ts @@ -0,0 +1,216 @@ +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId'; +import { GALLERY_IMAGE_PADDING_PX } from 'features/gallery/components/ImageGrid/ImageGridItemContainer'; +import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types'; +import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages'; +import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors'; +import { imageSelected } from 'features/gallery/store/gallerySlice'; +import { getIsVisible } from 'features/gallery/util/getIsVisible'; +import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign'; +import { clamp } from 'lodash-es'; +import { useCallback, useMemo } from 'react'; +import type { ImageDTO } from 'services/api/types'; +import { imagesSelectors } from 'services/api/util'; + +/** + * This hook is used to navigate the gallery using the arrow keys. + * + * The gallery is rendered as a grid. In order to navigate the grid, + * we need to know how many images are in each row and whether or not + * an image is visible in the gallery. + * + * We use direct DOM query selectors to check if an image is visible + * to avoid having to track a ref for each image. + */ + +/** + * Gets the number of images per row in the gallery by grabbing their DOM elements. + */ +const getImagesPerRow = (): number => { + const imageRect = Object.values( + document.getElementsByClassName('gallerygrid-image') + )[0]?.getBoundingClientRect(); + + // We have to manually take into account the padding of the image container, else + // imagesPerRow will be wrong when the gallery is large or images are very small. + const widthOfGalleryImage = imageRect + ? imageRect.width + GALLERY_IMAGE_PADDING_PX * 2 + : 0; + + const galleryGridRect = document + .getElementById('gallery-grid') + ?.getBoundingClientRect(); + + const widthOfGalleryGrid = galleryGridRect?.width ?? 0; + + const imagesPerRow = Math.floor(widthOfGalleryGrid / widthOfGalleryImage); + + return imagesPerRow; +}; + +/** + * Scrolls to the image with the given name. + * If the image is not fully visible, it will not be scrolled to. + * @param imageName The image name to scroll to. + * @param index The index of the image in the gallery. + */ +const scrollToImage = (imageName: string, index: number) => { + const virtuosoContext = virtuosoGridRefs.get(); + const range = virtuosoContext.virtuosoRangeRef?.current; + const root = virtuosoContext.rootRef?.current; + const virtuoso = virtuosoContext.virtuosoRef?.current; + + if (!range || !virtuoso || !root) { + return; + } + + const imageElement = document.querySelector( + `[data-testid="${getGalleryImageDataTestId(imageName)}"]` + ); + const itemRect = imageElement?.getBoundingClientRect(); + const rootRect = root.getBoundingClientRect(); + if (!itemRect || !getIsVisible(itemRect, rootRect)) { + virtuoso.scrollToIndex({ + index, + align: getScrollToIndexAlign(index, range), + }); + } +}; + +// Utilities to get the image to the left, right, up, or down of the current image. + +const getLeftImage = (images: ImageDTO[], currentIndex: number) => { + const index = clamp(currentIndex - 1, 0, images.length - 1); + const image = images[index]; + return { index, image }; +}; + +const getRightImage = (images: ImageDTO[], currentIndex: number) => { + const index = clamp(currentIndex + 1, 0, images.length - 1); + const image = images[index]; + return { index, image }; +}; + +const getUpImage = (images: ImageDTO[], currentIndex: number) => { + const imagesPerRow = getImagesPerRow(); + // If we are on the first row, we want to stay on the first row, not go to first image + const isOnFirstRow = currentIndex < imagesPerRow; + const index = isOnFirstRow + ? currentIndex + : clamp(currentIndex - imagesPerRow, 0, images.length - 1); + const image = images[index]; + return { index, image }; +}; + +const getDownImage = (images: ImageDTO[], currentIndex: number) => { + const imagesPerRow = getImagesPerRow(); + // If we are on the first row, we want to stay on the first row, not go to last image + const isOnLastRow = currentIndex >= images.length - imagesPerRow; + const index = isOnLastRow + ? currentIndex + : clamp(currentIndex + imagesPerRow, 0, images.length - 1); + const image = images[index]; + return { index, image }; +}; + +const getImageFuncs = { + left: getLeftImage, + right: getRightImage, + up: getUpImage, + down: getDownImage, +}; + +export type UseGalleryNavigationReturn = { + handleLeftImage: () => void; + handleRightImage: () => void; + handleUpImage: () => void; + handleDownImage: () => void; + isOnFirstImage: boolean; + isOnLastImage: boolean; + areImagesBelowCurrent: boolean; +}; + +/** + * Provides access to the gallery navigation via arrow keys. + * Also provides information about the current image's position in the gallery, + * useful for determining whether to load more images or display navigatin + * buttons. + */ +export const useGalleryNavigation = (): UseGalleryNavigationReturn => { + const dispatch = useAppDispatch(); + const lastSelectedImage = useAppSelector(selectLastSelectedImage); + const { + queryResult: { data }, + } = useGalleryImages(); + const loadedImagesCount = useMemo( + () => data?.ids.length ?? 0, + [data?.ids.length] + ); + const lastSelectedImageIndex = useMemo(() => { + if (!data || !lastSelectedImage) { + return 0; + } + return imagesSelectors + .selectAll(data) + .findIndex((i) => i.image_name === lastSelectedImage.image_name); + }, [lastSelectedImage, data]); + + const handleNavigation = useCallback( + (direction: 'left' | 'right' | 'up' | 'down') => { + if (!data) { + return; + } + const { index, image } = getImageFuncs[direction]( + imagesSelectors.selectAll(data), + lastSelectedImageIndex + ); + if (!image) { + return; + } + dispatch(imageSelected(image)); + scrollToImage(image.image_name, index); + }, + [dispatch, lastSelectedImageIndex, data] + ); + + const isOnFirstImage = useMemo( + () => lastSelectedImageIndex === 0, + [lastSelectedImageIndex] + ); + + const isOnLastImage = useMemo( + () => lastSelectedImageIndex === loadedImagesCount - 1, + [lastSelectedImageIndex, loadedImagesCount] + ); + + const areImagesBelowCurrent = useMemo(() => { + const imagesPerRow = getImagesPerRow(); + return lastSelectedImageIndex + imagesPerRow < loadedImagesCount; + }, [lastSelectedImageIndex, loadedImagesCount]); + + const handleLeftImage = useCallback(() => { + handleNavigation('left'); + }, [handleNavigation]); + + const handleRightImage = useCallback(() => { + handleNavigation('right'); + }, [handleNavigation]); + + const handleUpImage = useCallback(() => { + handleNavigation('up'); + }, [handleNavigation]); + + const handleDownImage = useCallback(() => { + handleNavigation('down'); + }, [handleNavigation]); + + return { + handleLeftImage, + handleRightImage, + handleUpImage, + handleDownImage, + isOnFirstImage, + isOnLastImage, + areImagesBelowCurrent, + }; +}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts index 3112ac3351..d8ee86fad4 100644 --- a/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts +++ b/invokeai/frontend/web/src/features/gallery/hooks/useMultiselect.ts @@ -1,6 +1,6 @@ import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { useGalleryImages } from 'features/gallery/hooks/useGalleryImages'; import { selectGallerySlice, selectionChanged, @@ -8,29 +8,24 @@ import { import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus'; import type { MouseEvent } from 'react'; import { useCallback, useMemo } from 'react'; -import { useListImagesQuery } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import { imagesSelectors } from 'services/api/util'; -const selector = createMemoizedSelector( - [selectGallerySlice, selectListImagesBaseQueryArgs], - (gallery, queryArgs) => { - return { - queryArgs, - selection: gallery.selection, - }; - } +const selectGallerySelection = createMemoizedSelector( + selectGallerySlice, + (gallery) => gallery.selection ); +const EMPTY_ARRAY: ImageDTO[] = []; + export const useMultiselect = (imageDTO?: ImageDTO) => { const dispatch = useAppDispatch(); - const { queryArgs, selection } = useAppSelector(selector); - - const { imageDTOs } = useListImagesQuery(queryArgs, { - selectFromResult: (result) => ({ - imageDTOs: result.data ? imagesSelectors.selectAll(result.data) : [], - }), - }); + const selection = useAppSelector(selectGallerySelection); + const { data } = useGalleryImages().queryResult; + const imageDTOs = useMemo( + () => (data ? imagesSelectors.selectAll(data) : EMPTY_ARRAY), + [data] + ); const isMultiSelectEnabled = useFeatureStatus('multiselect').isFeatureEnabled; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts b/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts deleted file mode 100644 index ede9004516..0000000000 --- a/invokeai/frontend/web/src/features/gallery/hooks/useNextPrevImage.ts +++ /dev/null @@ -1,244 +0,0 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import type { RootState } from 'app/store/store'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; -import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors'; -import { imageSelected } from 'features/gallery/store/gallerySlice'; -import { IMAGE_LIMIT } from 'features/gallery/store/types'; -import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign'; -import { clamp } from 'lodash-es'; -import { map } from 'nanostores'; -import type { RefObject } from 'react'; -import { useCallback } from 'react'; -import type { ListRange, VirtuosoGridHandle } from 'react-virtuoso'; -import { boardsApi } from 'services/api/endpoints/boards'; -import { - imagesApi, - useLazyListImagesQuery, -} from 'services/api/endpoints/images'; -import type { ListImagesArgs } from 'services/api/types'; -import { imagesSelectors } from 'services/api/util'; - -export type UseNextPrevImageState = { - virtuosoRef: RefObject | undefined; - virtuosoRangeRef: RefObject | undefined; -}; - -export const $useNextPrevImageState = map({ - virtuosoRef: undefined, - virtuosoRangeRef: undefined, -}); - -export const nextPrevImageButtonsSelector = createMemoizedSelector( - [(state: RootState) => state, selectListImagesBaseQueryArgs], - (state, baseQueryArgs) => { - const { data, status } = - imagesApi.endpoints.listImages.select(baseQueryArgs)(state); - - const { data: totalsData } = - state.gallery.galleryView === 'images' - ? boardsApi.endpoints.getBoardImagesTotal.select( - baseQueryArgs.board_id ?? 'none' - )(state) - : boardsApi.endpoints.getBoardAssetsTotal.select( - baseQueryArgs.board_id ?? 'none' - )(state); - - const lastSelectedImage = - state.gallery.selection[state.gallery.selection.length - 1]; - - - const isFetching = status === 'pending'; - - if (!data || !lastSelectedImage || totalsData?.total === 0) { - return { - isFetching, - queryArgs: baseQueryArgs, - isOnFirstImage: true, - isOnLastImage: true, - }; - } - - const queryArgs: ListImagesArgs = { - ...baseQueryArgs, - offset: data.ids.length, - limit: IMAGE_LIMIT, - }; - - const images = imagesSelectors.selectAll(data); - - const currentImageIndex = images.findIndex( - (i) => i.image_name === lastSelectedImage.image_name - ); - const widthOfGalleryImage = Object.values(document.getElementsByClassName("gallerygrid-image"))[0]?.clientWidth - const widthOfGalleryGrid = document.getElementById("gallery-grid")?.clientWidth - - - const imagesPerRow = Math.floor((widthOfGalleryGrid ?? 0) / (widthOfGalleryImage ?? 1)) - - - const nextImageIndex = clamp(currentImageIndex + 1, 0, images.length - 1); - const prevImageIndex = clamp(currentImageIndex - 1, 0, images.length - 1); - const topImageIndex = clamp(currentImageIndex - imagesPerRow,0, images.length - 1) - const bottomImageIndex = clamp(currentImageIndex + imagesPerRow,0, images.length - 1) - - const nextImageId = images[nextImageIndex]?.image_name; - const prevImageId = images[prevImageIndex]?.image_name; - const topImageId = images[topImageIndex]?.image_name - const bottomImageId = images[bottomImageIndex]?.image_name - - const nextImage = nextImageId - ? imagesSelectors.selectById(data, nextImageId) - : undefined; - const prevImage = prevImageId - ? imagesSelectors.selectById(data, prevImageId) - : undefined; - const topImage = topImageId - ? imagesSelectors.selectById(data, topImageId) - : undefined; - const bottomImage = bottomImageId - ? imagesSelectors.selectById(data, bottomImageId) - : undefined; - - const imagesLength = images.length; - - return { - loadedImagesCount: images.length, - currentImageIndex, - areMoreImagesAvailable: (totalsData?.total ?? 0) > imagesLength, - isFetching: status === 'pending', - nextImage, - prevImage, - nextImageIndex, - prevImageIndex, - queryArgs, - topImageIndex, - topImage, - bottomImageIndex, - bottomImage - }; - } -); - -export const useNextPrevImage = () => { - const dispatch = useAppDispatch(); - - const { - nextImage, - nextImageIndex, - prevImage, - prevImageIndex, - areMoreImagesAvailable, - isFetching, - queryArgs, - loadedImagesCount, - currentImageIndex, - topImageIndex, - topImage, - bottomImageIndex, - bottomImage - } = useAppSelector(nextPrevImageButtonsSelector); - - const handlePrevImage = useCallback(() => { - prevImage && dispatch(imageSelected(prevImage)); - const range = $useNextPrevImageState.get().virtuosoRangeRef?.current; - const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current; - if (!range || !virtuoso) { - return; - } - - if ( - prevImageIndex !== undefined && - (prevImageIndex < range.startIndex || prevImageIndex > range.endIndex) - ) { - virtuoso.scrollToIndex({ - index: prevImageIndex, - behavior: 'smooth', - align: getScrollToIndexAlign(prevImageIndex, range), - }); - } - }, [dispatch, prevImage, prevImageIndex]); - - const handleNextImage = useCallback(() => { - nextImage && dispatch(imageSelected(nextImage)); - const range = $useNextPrevImageState.get().virtuosoRangeRef?.current; - const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current; - if (!range || !virtuoso) { - return; - } - - if ( - nextImageIndex !== undefined && - (nextImageIndex < range.startIndex || nextImageIndex > range.endIndex) - ) { - virtuoso.scrollToIndex({ - index: nextImageIndex, - behavior: 'smooth', - align: getScrollToIndexAlign(nextImageIndex, range), - }); - } - }, [dispatch, nextImage, nextImageIndex]); - - const handleTopImage = useCallback(() => { - topImage && dispatch(imageSelected(topImage)); - const range = $useNextPrevImageState.get().virtuosoRangeRef?.current; - const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current; - - if (!range || !virtuoso) { - return; - } - - if ( - topImageIndex !== undefined && - (topImageIndex < range.startIndex || topImageIndex > range.endIndex) - ) { - virtuoso.scrollToIndex({ - index: topImageIndex, - behavior: 'smooth', - align: getScrollToIndexAlign(topImageIndex, range), - }); - } - },[dispatch, topImage, topImageIndex]) - - const handleBottomImage = useCallback(() => { - bottomImage && dispatch(imageSelected(bottomImage)); - const range = $useNextPrevImageState.get().virtuosoRangeRef?.current; - const virtuoso = $useNextPrevImageState.get().virtuosoRef?.current; - - if (!range || !virtuoso) { - return; - } - - if ( - bottomImageIndex !== undefined && - (bottomImageIndex < range.startIndex || bottomImageIndex > range.endIndex) - ) { - virtuoso.scrollToIndex({ - index: bottomImageIndex, - behavior: 'smooth', - align: getScrollToIndexAlign(bottomImageIndex, range), - }); - } - },[dispatch, bottomImage, bottomImageIndex]) - - const [listImages] = useLazyListImagesQuery(); - - const handleLoadMoreImages = useCallback(() => { - listImages(queryArgs); - }, [listImages, queryArgs]); - - return { - handlePrevImage, - handleNextImage, - isOnFirstImage: currentImageIndex === 0, - isOnLastImage: - currentImageIndex !== undefined && - currentImageIndex === loadedImagesCount - 1, - nextImage, - prevImage, - areMoreImagesAvailable, - handleLoadMoreImages, - isFetching, - handleTopImage, - handleBottomImage - }; -}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useScrollIntoView.ts b/invokeai/frontend/web/src/features/gallery/hooks/useScrollIntoView.ts new file mode 100644 index 0000000000..6bcafba073 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useScrollIntoView.ts @@ -0,0 +1,52 @@ +import { virtuosoGridRefs } from 'features/gallery/components/ImageGrid/types'; +import { getIsVisible } from 'features/gallery/util/getIsVisible'; +import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign'; +import { useEffect, useRef } from 'react'; + +/** + * Scrolls an image into view when it is selected. This is necessary because + * the image grid is virtualized, so the image may not be visible when it is + * selected. + * + * Also handles when an image is selected programmatically - for example, when + * auto-switching the new gallery images. + * + * @param isSelected Whether the image is selected. + * @param index The index of the image in the gallery. + * @param selectionCount The number of images selected. + * @returns + */ +export const useScrollIntoView = ( + isSelected: boolean, + index: number, + selectionCount: number +) => { + const imageContainerRef = useRef(null); + + useEffect(() => { + if (!isSelected || selectionCount !== 1) { + return; + } + + const virtuosoContext = virtuosoGridRefs.get(); + const range = virtuosoContext.virtuosoRangeRef?.current; + const root = virtuosoContext.rootRef?.current; + const virtuoso = virtuosoContext.virtuosoRef?.current; + + if (!range || !virtuoso || !root) { + return; + } + + const itemRect = imageContainerRef.current?.getBoundingClientRect(); + const rootRect = root.getBoundingClientRect(); + + if (!itemRect || !getIsVisible(itemRect, rootRect)) { + virtuoso.scrollToIndex({ + index, + align: getScrollToIndexAlign(index, range), + }); + } + }, [isSelected, index, selectionCount]); + + return imageContainerRef; +}; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useScrollToVisible.ts b/invokeai/frontend/web/src/features/gallery/hooks/useScrollToVisible.ts deleted file mode 100644 index 27eb0e62a4..0000000000 --- a/invokeai/frontend/web/src/features/gallery/hooks/useScrollToVisible.ts +++ /dev/null @@ -1,46 +0,0 @@ -import type { VirtuosoGalleryContext } from 'features/gallery/components/ImageGrid/types'; -import { getScrollToIndexAlign } from 'features/gallery/util/getScrollToIndexAlign'; -import { useEffect, useRef } from 'react'; - -export const useScrollToVisible = ( - isSelected: boolean, - index: number, - selectionCount: number, - virtuosoContext: VirtuosoGalleryContext -) => { - const imageContainerRef = useRef(null); - - useEffect(() => { - if ( - !isSelected || - selectionCount !== 1 || - !virtuosoContext.rootRef.current || - !virtuosoContext.virtuosoRef.current || - !virtuosoContext.virtuosoRangeRef.current || - !imageContainerRef.current - ) { - return; - } - - const itemRect = imageContainerRef.current.getBoundingClientRect(); - const rootRect = virtuosoContext.rootRef.current.getBoundingClientRect(); - const itemIsVisible = - itemRect.top >= rootRect.top && - itemRect.bottom <= rootRect.bottom && - itemRect.left >= rootRect.left && - itemRect.right <= rootRect.right; - - if (!itemIsVisible) { - virtuosoContext.virtuosoRef.current.scrollToIndex({ - index, - behavior: 'smooth', - align: getScrollToIndexAlign( - index, - virtuosoContext.virtuosoRangeRef.current - ), - }); - } - }, [isSelected, index, selectionCount, virtuosoContext]); - - return imageContainerRef; -}; diff --git a/invokeai/frontend/web/src/features/gallery/util/getIsVisible.ts b/invokeai/frontend/web/src/features/gallery/util/getIsVisible.ts new file mode 100644 index 0000000000..9f288ac8a0 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/util/getIsVisible.ts @@ -0,0 +1,12 @@ +/** + * Gets whether the item is visible in the root element. + */ + +export const getIsVisible = (itemRect: DOMRect, rootRect: DOMRect) => { + return ( + itemRect.top >= rootRect.top && + itemRect.bottom <= rootRect.bottom && + itemRect.left >= rootRect.left && + itemRect.right <= rootRect.right + ); +};