mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@ -622,3 +622,31 @@ async def get_image_collection(
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get collection images")
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)"""
|
||||
|
||||
try:
|
||||
image_names = ApiDependencies.invoker.services.images.get_image_names(
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return image_names
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image names")
|
||||
|
@ -126,3 +126,16 @@ class ImageRecordStorageBase(ABC):
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets images from a specific collection (starred or unstarred)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)."""
|
||||
pass
|
||||
|
@ -561,3 +561,76 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Base query to get image names in order (starred first, then unstarred)
|
||||
query = """--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Order by starred first, then by created_at
|
||||
query += query_conditions + f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
|
||||
cursor.execute(query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
||||
return [row[0] for row in result]
|
||||
|
@ -176,3 +176,16 @@ class ImageServiceABC(ABC):
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images from a specific collection (starred or unstarred)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)."""
|
||||
pass
|
||||
|
@ -376,3 +376,25 @@ class ImageService(ImageServiceABC):
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting collection images")
|
||||
raise e
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting image names")
|
||||
raise e
|
||||
|
@ -5,7 +5,7 @@ import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySe
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { uniq } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
|
||||
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
|
||||
|
||||
// Type for image collection query arguments
|
||||
type ImageCollectionQueryArgs = {
|
||||
@ -17,53 +17,12 @@ type ImageCollectionQueryArgs = {
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper function to get all cached image data from collection queries
|
||||
* Returns a combined array of starred images followed by unstarred images
|
||||
* Helper function to get cached image names list for selection operations
|
||||
* Returns an ordered array of image names (starred first, then unstarred)
|
||||
*/
|
||||
const getCachedImageList = (state: RootState, queryArgs: ImageCollectionQueryArgs): ImageDTO[] => {
|
||||
const countsQueryResult = imagesApi.endpoints.getImageCollectionCounts.select(queryArgs)(state);
|
||||
|
||||
if (!countsQueryResult.data) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const { starred_count, unstarred_count } = countsQueryResult.data;
|
||||
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
// Add starred images first (in order)
|
||||
if (starred_count > 0) {
|
||||
for (let offset = 0; offset < starred_count; offset += 50) {
|
||||
const queryResult = imagesApi.endpoints.getImageCollection.select({
|
||||
collection: 'starred',
|
||||
offset,
|
||||
limit: 50,
|
||||
...queryArgs,
|
||||
})(state);
|
||||
|
||||
if (queryResult.data?.items) {
|
||||
imageDTOs.push(...queryResult.data.items);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add unstarred images (in order)
|
||||
if (unstarred_count > 0) {
|
||||
for (let offset = 0; offset < unstarred_count; offset += 50) {
|
||||
const queryResult = imagesApi.endpoints.getImageCollection.select({
|
||||
collection: 'unstarred',
|
||||
offset,
|
||||
limit: 50,
|
||||
...queryArgs,
|
||||
})(state);
|
||||
|
||||
if (queryResult.data?.items) {
|
||||
imageDTOs.push(...queryResult.data.items);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return imageDTOs;
|
||||
const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => {
|
||||
const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state);
|
||||
return queryResult.data || [];
|
||||
};
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
@ -93,12 +52,12 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
const state = getState();
|
||||
const queryArgs = selectImageCollectionQueryArgs(state);
|
||||
|
||||
// Get all cached image data
|
||||
const imageDTOs = getCachedImageList(state, queryArgs);
|
||||
// Get cached image names for selection operations
|
||||
const imageNames = getCachedImageNames(state, queryArgs);
|
||||
|
||||
// If we don't have the image data cached, we can't perform selection operations
|
||||
// This can happen if the user clicks on an image before all data is loaded
|
||||
if (imageDTOs.length === 0) {
|
||||
// If we don't have the image names cached, we can't perform selection operations
|
||||
// This can happen if the user clicks on an image before the names are loaded
|
||||
if (imageNames.length === 0) {
|
||||
// For basic click without modifiers, we can still set selection
|
||||
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
|
||||
dispatch(selectionChanged([imageName]));
|
||||
@ -117,13 +76,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
} else if (shiftKey) {
|
||||
const rangeEndImageName = imageName;
|
||||
const lastSelectedImage = selection.at(-1);
|
||||
const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage);
|
||||
const currentClickedIndex = imageDTOs.findIndex((n) => n.image_name === rangeEndImageName);
|
||||
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
|
||||
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
|
||||
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
|
||||
// We have a valid range!
|
||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||
const imagesToSelect = imageDTOs.slice(start, end + 1).map(({ image_name }) => image_name);
|
||||
const imagesToSelect = imageNames.slice(start, end + 1);
|
||||
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
|
||||
}
|
||||
} else if (ctrlKey || metaKey) {
|
||||
|
@ -14,7 +14,11 @@ import type {
|
||||
VirtuosoGridHandle,
|
||||
} from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
|
||||
import {
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
@ -167,6 +171,10 @@ export const NewGallery = memo(() => {
|
||||
|
||||
const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions);
|
||||
|
||||
// Load image names for selection operations - this is lightweight and ensures
|
||||
// selection operations work even before image data is fully loaded
|
||||
useGetImageNamesQuery(queryArgs);
|
||||
|
||||
// Reset scroll position when query parameters change
|
||||
useEffect(() => {
|
||||
if (virtuosoRef.current && counts.total_count > 0) {
|
||||
|
@ -5,11 +5,13 @@ import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/type
|
||||
import type { components, paths } from 'services/api/schema';
|
||||
import type {
|
||||
GraphAndWorkflowResponse,
|
||||
ImageCategory,
|
||||
ImageDTO,
|
||||
ImageUploadEntryRequest,
|
||||
ImageUploadEntryResponse,
|
||||
ListImagesArgs,
|
||||
ListImagesResponse,
|
||||
SQLiteDirection,
|
||||
UploadImageArg,
|
||||
} from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
@ -471,6 +473,26 @@ export const imagesApi = api.injectEndpoints({
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
},
|
||||
}),
|
||||
/**
|
||||
* Get ordered list of image names for selection operations
|
||||
*/
|
||||
getImageNames: build.query<
|
||||
string[],
|
||||
{
|
||||
image_origin?: 'internal' | 'external' | null;
|
||||
categories?: ImageCategory[] | null;
|
||||
is_intermediate?: boolean | null;
|
||||
board_id?: string | null;
|
||||
search_term?: string | null;
|
||||
order_dir?: SQLiteDirection;
|
||||
}
|
||||
>({
|
||||
query: (queryArgs) => ({
|
||||
url: buildImagesUrl('names', queryArgs),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['ImageNameList', 'FetchOnReconnect'],
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@ -495,6 +517,7 @@ export const {
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
} = imagesApi;
|
||||
|
||||
/**
|
||||
|
Reference in New Issue
Block a user