refactor: remove unused methods/routes, fix some gallery invalidation issues

This commit is contained in:
psychedelicious
2025-06-25 13:44:57 +10:00
parent 98368b0665
commit b2b42be51c
22 changed files with 139 additions and 657 deletions

View File

@ -1,7 +1,7 @@
import io
import json
import traceback
from typing import ClassVar, Literal, Optional
from typing import ClassVar, Optional
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse
@ -14,7 +14,6 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecordChanges,
ResourceOrigin,
)
@ -565,67 +564,6 @@ async def get_bulk_download_item(
raise HTTPException(status_code=404)
@images_router.get(
"/collections/counts", operation_id="get_image_collection_counts", response_model=ImageCollectionCounts
)
async def get_image_collection_counts(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections"""
try:
return ApiDependencies.invoker.services.images.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection counts")
@images_router.get("/collections/{collection}", operation_id="get_image_collection")
async def get_image_collection(
collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"),
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
board_id: Optional[str] = Query(
default=None,
description="The board id to filter by. Use 'none' to find images without a board.",
),
offset: int = Query(default=0, description="The offset within the collection"),
limit: int = Query(default=50, description="The number of images to return"),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)"""
try:
image_dtos = ApiDependencies.invoker.services.images.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get collection images")
@images_router.get("/names", operation_id="get_image_names")
async def get_image_names(
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
@ -636,12 +574,14 @@ async def get_image_names(
description="The board id to filter by. Use 'none' to find images without a board.",
),
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)"""
try:
image_names = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,

View File

@ -587,9 +587,9 @@ def invocation(
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
assert isinstance(
field_info.json_schema_extra, dict
), f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
@ -712,9 +712,9 @@ def invocation_output(
for field_name, field_info in cls.model_fields.items():
annotation = field_info.annotation
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
assert isinstance(field_info.json_schema_extra, dict), (
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
)
assert isinstance(
field_info.json_schema_extra, dict
), f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)

View File

@ -184,9 +184,9 @@ class SegmentAnythingInvocation(BaseInvocation):
# Find the largest mask.
return [max(masks, key=lambda x: float(x.sum()))]
elif self.mask_filter == "highest_box_score":
assert bounding_boxes is not None, (
"Bounding boxes must be provided to use the 'highest_box_score' mask filter."
)
assert (
bounding_boxes is not None
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
assert len(masks) == len(bounding_boxes)
# Find the index of the bounding box with the highest score.
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most

View File

@ -482,9 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
assert config.schema_version == CONFIG_SCHEMA_VERSION, (
f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@ -1,11 +1,10 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Literal, Optional
from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@ -99,37 +98,10 @@ class ImageRecordStorageBase(ABC):
"""Gets the most recent image for a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets images from a specific collection (starred or unstarred)."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,

View File

@ -1,13 +1,12 @@
import sqlite3
from datetime import datetime
from typing import Literal, Optional, Union, cast
from typing import Optional, Union, cast
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@ -388,182 +387,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
cursor = self._conn.cursor()
# Build the base query conditions (same as get_many)
base_query = """--sql
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count
starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;"
cursor.execute(starred_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get unstarred count
unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;"
cursor.execute(unstarred_query, query_params)
unstarred_count = cast(int, cursor.fetchone()[0])
return ImageCollectionCounts(starred_count=starred_count, unstarred_count=unstarred_count)
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
cursor = self._conn.cursor()
# Base queries
count_query = """--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
images_query = f"""--sql
SELECT {IMAGE_DTO_COLS}
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
# Add starred/unstarred filter
is_starred = collection == "starred"
query_conditions += """--sql
AND images.starred = ?
"""
query_params.append(is_starred)
if image_origin is not None:
query_conditions += """--sql
AND images.image_origin = ?
"""
query_params.append(image_origin.value)
if categories is not None:
category_strings = [c.value for c in set(categories)]
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"""--sql
AND images.image_category IN ( {placeholders} )
"""
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += """--sql
AND images.is_intermediate = ?
"""
query_params.append(is_intermediate)
if board_id == "none":
query_conditions += """--sql
AND board_images.board_id IS NULL
"""
elif board_id is not None:
query_conditions += """--sql
AND board_images.board_id = ?
"""
query_params.append(board_id)
if search_term:
query_conditions += """--sql
AND (
images.metadata LIKE ?
OR images.created_at LIKE ?
)
"""
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Add ordering and pagination
query_pagination = f"""--sql
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
"""
# Execute images query
images_query += query_conditions + query_pagination + ";"
images_params = query_params.copy()
images_params.extend([limit, offset])
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Execute count query
count_query += query_conditions + ";"
cursor.execute(count_query, query_params)
count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
@ -625,13 +451,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Order by starred first, then by created_at
query += (
query_conditions
+ f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
)
if starred_first:
query += (
query_conditions
+ f"""--sql
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
)
else:
query += (
query_conditions
+ f"""--sql
ORDER BY images.created_at {order_dir.value}
"""
)
cursor.execute(query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())

View File

@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from typing import Callable, Literal, Optional
from typing import Callable, Optional
from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@ -149,37 +148,10 @@ class ImageServiceABC(ABC):
"""Deletes all images on a board."""
pass
@abstractmethod
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
"""Gets counts for starred and unstarred image collections."""
pass
@abstractmethod
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets images from a specific collection (starred or unstarred)."""
pass
@abstractmethod
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
@ -187,5 +159,5 @@ class ImageServiceABC(ABC):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)."""
"""Gets ordered list of all image names."""
pass

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Optional
from PIL.Image import Image as PILImageType
@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageCollectionCounts,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@ -311,73 +310,9 @@ class ImageService(ImageServiceABC):
self.__invoker.services.logger.error("Problem getting intermediates count")
raise e
def get_collection_counts(
self,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageCollectionCounts:
try:
return self.__invoker.services.image_records.get_collection_counts(
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection counts")
raise e
def get_collection_images(
self,
collection: Literal["starred", "unstarred"],
offset: int = 0,
limit: int = 10,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self.__invoker.services.image_records.get_collection_images(
collection=collection,
offset=offset,
limit=limit,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,
is_intermediate=is_intermediate,
board_id=board_id,
search_term=search_term,
)
image_dtos = [
image_record_to_dto(
image_record=r,
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
)
for r in results.items
]
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:
self.__invoker.services.logger.error("Problem getting collection images")
raise e
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
@ -387,6 +322,7 @@ class ImageService(ImageServiceABC):
) -> list[str]:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
categories=categories,

View File

@ -379,13 +379,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
bytes_ = path.read_bytes()
workflow_from_file = WorkflowValidator.validate_json(bytes_)
assert workflow_from_file.id.startswith("default_"), (
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
)
assert workflow_from_file.id.startswith(
"default_"
), f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
f"Invalid default workflow category: {workflow_from_file.meta.category}"
)
assert (
workflow_from_file.meta.category is WorkflowCategory.Default
), f"Invalid default workflow category: {workflow_from_file.meta.category}"
workflows_from_file.append(workflow_from_file)

View File

@ -381,7 +381,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
state_dict = mod.load_state_dict()
for key in state_dict.keys():
if type(key) is int:
if isinstance(key, int):
continue
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):

View File

@ -115,19 +115,19 @@ class ModelMerger(object):
base_models: Set[BaseModelType] = set()
variant = None if self._installer.app_config.precision == "float32" else "fp16"
assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, (
"When merging three models, only the 'add_difference' merge method is supported"
)
assert (
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
), "When merging three models, only the 'add_difference' merge method is supported"
for key in model_keys:
info = store.get_model(key)
model_names.append(info.name)
assert isinstance(info, MainDiffusersConfig), (
f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
)
assert info.variant == ModelVariantType("normal"), (
f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
)
assert isinstance(
info, MainDiffusersConfig
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
assert info.variant == ModelVariantType(
"normal"
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
# tally base models used
base_models.add(info.base)

View File

@ -1,6 +1,6 @@
import { isAnyOf } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
@ -13,7 +13,7 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryArgs = { ...selectListImagesBaseQueryArgs(state), offset: 0 };
// wait until the board has some images - maybe it already has some from a previous fetch
// must use getState() to ensure we do not have stale state

View File

@ -1,29 +1,9 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import type { RootState } from 'app/store/store';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { uniq } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
// Type for image collection query arguments
type ImageCollectionQueryArgs = {
board_id?: string;
categories?: ImageCategory[];
search_term?: string;
order_dir?: SQLiteDirection;
is_intermediate: boolean;
};
/**
* Helper function to get cached image names list for selection operations
* Returns an ordered array of image names (starred first, then unstarred)
*/
const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => {
const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state);
return queryResult.data || [];
};
export const galleryImageClicked = createAction<{
imageName: string;
@ -50,10 +30,8 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
effect: (action, { dispatch, getState }) => {
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
// Get cached image names for selection operations
const imageNames = getCachedImageNames(state, queryArgs);
const queryArgs = selectListImageNamesQueryArgs(state);
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? [];
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded

View File

@ -10,7 +10,6 @@ import {
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState, RefImagesState } from 'features/controlLayers/store/types';
import type { ImageUsage } from 'features/deleteImageModal/store/types';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageSelected } from 'features/gallery/store/gallerySlice';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectNodesSlice } from 'features/nodes/store/selectors';
@ -81,14 +80,8 @@ const handleDeletions = async (image_names: string[], dispatch: AppDispatch, get
await dispatch(imagesApi.endpoints.deleteImages.initiate({ image_names }, { track: false })).unwrap();
if (intersection(state.gallery.selection, image_names).length > 0) {
// Some selected images were deleted, need to select the next image
const queryArgs = selectListImagesQueryArgs(state);
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
if (data) {
// When we delete multiple images, we clear the selection. Then, the the next time we load images, we will
// select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`.
dispatch(imageSelected(null));
}
// Some selected images were deleted, clear selection
dispatch(imageSelected(null));
}
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist

View File

@ -5,6 +5,7 @@ import { CanvasAlertsInvocationProgress } from 'features/controlLayers/component
import { DndImage } from 'features/dnd/DndImage';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors';
import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common';
import { selectShouldShowImageDetails, selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
import type { AnimationProps } from 'framer-motion';
@ -21,6 +22,7 @@ import { ProgressIndicator } from './ProgressIndicator2';
export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => {
const shouldShowImageDetails = useAppSelector(selectShouldShowImageDetails);
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
const autoSwitch = useAppSelector(selectAutoSwitch);
const socket = useStore($socket);
const [progressEvent, setProgressEvent] = useState<S['InvocationProgressEvent'] | null>(null);
@ -58,6 +60,29 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
};
}, [socket]);
useEffect(() => {
if (!socket) {
return;
}
if (autoSwitch) {
return;
}
// When auto-switch is enabled, we will get a load event as we switch to the new image. This in turn clears the progress image,
// creating the illusion of the progress image turning into the new image.
// But when auto-switch is disabled, we won't get that load event, so we need to clear the progress image manually.
const onQueueItemStatusChanged = () => {
setProgressEvent(null);
setProgressImage(null);
};
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
return () => {
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
};
}, [autoSwitch, socket]);
const onLoadImage = useCallback(() => {
if (!progressEvent || !imageDTO) {
return;

View File

@ -4,10 +4,11 @@ import { logger } from 'app/logging/logger';
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import {
LIMIT,
selectGalleryImageMinimumWidth,
selectImageToCompare,
selectLastSelectedImage,
selectListImagesQueryArgs,
selectListImageNamesQueryArgs,
} from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
@ -37,31 +38,26 @@ const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096;
const DEBOUNCE_DELAY = 500;
const SPINNER_OPACITY = 0.3;
type ListImagesQueryArgs = ReturnType<typeof selectListImagesQueryArgs>;
type ListImageNamesQueryArgs = ReturnType<typeof selectListImageNamesQueryArgs>;
type GridContext = {
queryArgs: ListImagesQueryArgs;
queryArgs: ListImageNamesQueryArgs;
imageNames: string[];
};
export const useDebouncedListImagesQueryArgs = () => {
const _galleryQueryArgs = useAppSelector(selectListImagesQueryArgs);
const [queryArgs] = useDebounce(_galleryQueryArgs, DEBOUNCE_DELAY);
return queryArgs;
};
// Hook to get an image DTO from cache or trigger loading
const useImageDTOFromListQuery = (
index: number,
imageName: string,
queryArgs: ListImagesQueryArgs
queryArgs: ListImageNamesQueryArgs
): ImageDTO | null => {
const { arg, options } = useMemo(() => {
const pageOffset = Math.floor(index / queryArgs.limit) * queryArgs.limit;
const pageOffset = Math.floor(index / LIMIT) * LIMIT;
return {
arg: {
...queryArgs,
offset: pageOffset,
limit: LIMIT,
} satisfies Parameters<typeof useListImagesQuery>[0],
options: {
selectFromResult: ({ data }) => {
@ -82,7 +78,7 @@ const useImageDTOFromListQuery = (
// Individual image component that gets its data from RTK Query cache
const ImageAtPosition = memo(
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImagesQueryArgs }) => {
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => {
const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs);
if (!imageDTO) {
@ -408,7 +404,8 @@ const getImageNamesQueryOptions = {
} satisfies Parameters<typeof useGetImageNamesQuery>[1];
export const useGalleryImageNames = () => {
const queryArgs = useDebouncedListImagesQueryArgs();
const _queryArgs = useAppSelector(selectListImageNamesQueryArgs);
const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY);
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions);
return { imageNames, isLoading, isFetching, queryArgs };
};

View File

@ -2,8 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import type { ListBoardsArgs, ListImagesArgs } from 'services/api/types';
import type { SetNonNullable } from 'type-fest';
import type { ListBoardsArgs } from 'services/api/types';
export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0));
export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1));
@ -28,7 +27,7 @@ export const selectGallerySearchTerm = createSelector(selectGallerySlice, (galle
export const selectGalleryOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.orderDir);
export const selectGalleryStarredFirst = createSelector(selectGallerySlice, (gallery) => gallery.starredFirst);
export const selectListImagesQueryArgs = createMemoizedSelector(
export const selectListImageNamesQueryArgs = createMemoizedSelector(
[
selectSelectedBoardId,
selectGalleryQueryCategories,
@ -36,17 +35,20 @@ export const selectListImagesQueryArgs = createMemoizedSelector(
selectGalleryOrderDir,
selectGalleryStarredFirst,
],
(board_id, categories, search_term, order_dir, starred_first) =>
({
board_id,
categories,
search_term,
order_dir,
starred_first,
is_intermediate: false, // We don't show intermediate images in the gallery
limit: 100, // Page size is _always_ 100
}) satisfies SetNonNullable<ListImagesArgs, 'limit'>
(board_id, categories, search_term, order_dir, starred_first) => ({
board_id,
categories,
search_term,
order_dir,
starred_first,
is_intermediate: false,
})
);
export const LIMIT = 100;
export const selectListImagesBaseQueryArgs = createMemoizedSelector(selectListImageNamesQueryArgs, (baseQueryArgs) => ({
...baseQueryArgs,
limit: LIMIT,
}));
export const selectAutoAssignBoardOnClick = createSelector(
selectGallerySlice,
(gallery) => gallery.autoAssignBoardOnClick

View File

@ -427,61 +427,12 @@ export const imagesApi = api.injectEndpoints({
},
}),
}),
/**
* Get counts for starred and unstarred image collections
*/
getImageCollectionCounts: build.query<
paths['/api/v1/images/collections/counts']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/images/collections/counts']['get']['parameters']['query']
>({
query: (queryArgs) => ({
url: buildImagesUrl('collections/counts', queryArgs),
method: 'GET',
}),
providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'],
}),
/**
* Get images from a specific collection (starred or unstarred)
*/
getImageCollection: build.query<
paths['/api/v1/images/collections/{collection}']['get']['responses']['200']['content']['application/json'],
paths['/api/v1/images/collections/{collection}']['get']['parameters']['path'] &
paths['/api/v1/images/collections/{collection}']['get']['parameters']['query']
>({
query: ({ collection, ...queryArgs }) => ({
url: buildImagesUrl(`collections/${collection}`, queryArgs),
method: 'GET',
}),
providesTags: (result, error, { collection, board_id, categories }) => {
const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`;
return [
{ type: 'ImageCollection', id: collection },
{ type: 'ImageCollection', id: cacheKey },
'FetchOnReconnect',
];
},
async onQueryStarted(_, { dispatch, queryFulfilled }) {
// Populate the getImageDTO cache with these images, similar to listImages
const res = await queryFulfilled;
const imageDTOs = res.data.items;
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
for (const imageDTO of imageDTOs) {
updates.push({
endpointName: 'getImageDTO',
arg: imageDTO.image_name,
value: imageDTO,
});
}
dispatch(imagesApi.util.upsertQueryEntries(updates));
},
}),
/**
* Get ordered list of image names for selection operations
*/
getImageNames: build.query<
string[],
{
image_origin?: 'internal' | 'external' | null;
categories?: ImageCategory[] | null;
is_intermediate?: boolean | null;
board_id?: string | null;
@ -493,46 +444,11 @@ export const imagesApi = api.injectEndpoints({
url: buildImagesUrl('names', queryArgs),
method: 'GET',
}),
providesTags: ['ImageNameList', 'FetchOnReconnect'],
}),
/**
* Get paginated images with starred first (unified list)
*/
getUnifiedImageList: build.query<
ListImagesResponse,
{
offset?: number;
limit?: number;
image_origin?: 'internal' | 'external' | null;
categories?: ImageCategory[] | null;
is_intermediate?: boolean | null;
board_id?: string | null;
search_term?: string | null;
order_dir?: SQLiteDirection;
}
>({
query: (queryArgs) => ({
url: getListImagesUrl({ ...queryArgs, starred_first: true }),
method: 'GET',
}),
providesTags: (result, error, { board_id, categories }) => [
{ type: 'ImageList', id: getListImagesUrl({ board_id, categories }) },
providesTags: (result, error, queryArgs) => [
'ImageNameList',
'FetchOnReconnect',
{ type: 'ImageNameList', id: stableHash(queryArgs) },
],
async onQueryStarted(_, { dispatch, queryFulfilled }) {
// Populate the getImageDTO cache with these images
const res = await queryFulfilled;
const imageDTOs = res.data.items;
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
for (const imageDTO of imageDTOs) {
updates.push({
endpointName: 'getImageDTO',
arg: imageDTO.image_name,
value: imageDTO,
});
}
dispatch(imagesApi.util.upsertQueryEntries(updates));
},
}),
}),
});
@ -555,11 +471,7 @@ export const {
useStarImagesMutation,
useUnstarImagesMutation,
useBulkDownloadImagesMutation,
useGetImageCollectionCountsQuery,
useGetImageCollectionQuery,
useLazyGetImageCollectionQuery,
useGetImageNamesQuery,
useGetUnifiedImageListQuery,
} = imagesApi;
/**

View File

@ -752,7 +752,7 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/collections/counts": {
"/api/v1/images/names": {
parameters: {
query?: never;
header?: never;
@ -760,30 +760,10 @@ export type paths = {
cookie?: never;
};
/**
* Get Image Collection Counts
* @description Gets counts for starred and unstarred image collections
* Get Image Names
* @description Gets ordered list of all image names (starred first, then unstarred)
*/
get: operations["get_image_collection_counts"];
put?: never;
post?: never;
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/images/collections/{collection}": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
/**
* Get Image Collection
* @description Gets images from a specific collection (starred or unstarred)
*/
get: operations["get_image_collection"];
get: operations["get_image_names"];
put?: never;
post?: never;
delete?: never;
@ -9844,19 +9824,6 @@ export type components = {
*/
type: "img_channel_offset";
};
/** ImageCollectionCounts */
ImageCollectionCounts: {
/**
* Starred Count
* @description The number of starred images in the collection.
*/
starred_count: number;
/**
* Unstarred Count
* @description The number of unstarred images in the collection.
*/
unstarred_count: number;
};
/**
* Image Collection Primitive
* @description A collection of image primitive values
@ -23728,17 +23695,21 @@ export interface operations {
};
};
};
get_image_collection_counts: {
get_image_names: {
parameters: {
query?: {
/** @description The origin of images to count. */
/** @description The origin of images to list. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to include intermediate images. */
/** @description Whether to list intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The order of sort */
order_dir?: components["schemas"]["SQLiteDirection"];
/** @description Whether to sort by starred images first */
starred_first?: boolean;
/** @description The term to search for */
search_term?: string | null;
};
@ -23754,56 +23725,7 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["ImageCollectionCounts"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_image_collection: {
parameters: {
query?: {
/** @description The origin of images to list. */
image_origin?: components["schemas"]["ResourceOrigin"] | null;
/** @description The categories of image to include. */
categories?: components["schemas"]["ImageCategory"][] | null;
/** @description Whether to list intermediate images. */
is_intermediate?: boolean | null;
/** @description The board id to filter by. Use 'none' to find images without a board. */
board_id?: string | null;
/** @description The offset within the collection */
offset?: number;
/** @description The number of images to return */
limit?: number;
/** @description The order of sort */
order_dir?: components["schemas"]["SQLiteDirection"];
/** @description The term to search for */
search_term?: string | null;
};
header?: never;
path: {
/** @description The collection to retrieve from */
collection: "starred" | "unstarred";
};
cookie?: never;
};
requestBody?: never;
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"];
"application/json": string[];
};
};
/** @description Validation Error */

View File

@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import {
selectAutoSwitch,
selectGalleryView,
selectListImagesQueryArgs,
selectListImagesBaseQueryArgs,
selectSelectedBoardId,
} from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
@ -44,7 +44,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
const boardTotalAdditions: Record<string, number> = {};
const boardTagIdsToInvalidate: Set<string> = new Set();
const imageListTagIdsToInvalidate: Set<string> = new Set();
const listImagesArg = selectListImagesQueryArgs(getState());
const listImagesArg = selectListImagesBaseQueryArgs(getState());
for (const imageDTO of imageDTOs) {
if (imageDTO.is_intermediate) {
@ -94,7 +94,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
type: 'ImageList' as const,
id: imageListId,
}));
dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags]));
dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags]));
const autoSwitch = selectAutoSwitch(getState());

View File

@ -211,12 +211,12 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path(tmp_path, "sdxl-turbo/model_index.json").exists(), (
f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
)
assert Path(tmp_path, "sdxl-turbo/text_encoder/config.json").exists(), (
f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
)
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
assert Path(
tmp_path, "sdxl-turbo/text_encoder/config.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()

View File

@ -48,9 +48,9 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
model_keys = set(model.state_dict().keys())
for converted_key_prefix in converted_key_prefixes:
assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), (
f"'{converted_key_prefix}' did not match any model keys."
)
assert any(
model_key.startswith(converted_key_prefix) for model_key in model_keys
), f"'{converted_key_prefix}' did not match any model keys."
def test_lora_model_from_flux_aitoolkit_state_dict():