mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
refactor: remove unused methods/routes, fix some gallery invalidation issues
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
import io
|
||||
import json
|
||||
import traceback
|
||||
from typing import ClassVar, Literal, Optional
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
@ -14,7 +14,6 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageCollectionCounts,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
@ -565,67 +564,6 @@ async def get_bulk_download_item(
|
||||
raise HTTPException(status_code=404)
|
||||
|
||||
|
||||
@images_router.get(
|
||||
"/collections/counts", operation_id="get_image_collection_counts", response_model=ImageCollectionCounts
|
||||
)
|
||||
async def get_image_collection_counts(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> ImageCollectionCounts:
|
||||
"""Gets counts for starred and unstarred image collections"""
|
||||
|
||||
try:
|
||||
return ApiDependencies.invoker.services.images.get_collection_counts(
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get collection counts")
|
||||
|
||||
|
||||
@images_router.get("/collections/{collection}", operation_id="get_image_collection")
|
||||
async def get_image_collection(
|
||||
collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"),
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."),
|
||||
is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."),
|
||||
board_id: Optional[str] = Query(
|
||||
default=None,
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
offset: int = Query(default=0, description="The offset within the collection"),
|
||||
limit: int = Query(default=50, description="The number of images to return"),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images from a specific collection (starred or unstarred)"""
|
||||
|
||||
try:
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_collection_images(
|
||||
collection=collection,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get collection images")
|
||||
|
||||
|
||||
@images_router.get("/names", operation_id="get_image_names")
|
||||
async def get_image_names(
|
||||
image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."),
|
||||
@ -636,12 +574,14 @@ async def get_image_names(
|
||||
description="The board id to filter by. Use 'none' to find images without a board.",
|
||||
),
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)"""
|
||||
|
||||
try:
|
||||
image_names = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
|
@ -587,9 +587,9 @@ def invocation(
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation {invocation_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
assert isinstance(
|
||||
field_info.json_schema_extra, dict
|
||||
), f"{field_name} on invocation {invocation_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
|
||||
original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
@ -712,9 +712,9 @@ def invocation_output(
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
annotation = field_info.annotation
|
||||
assert annotation is not None, f"{field_name} on invocation output {output_type} has no type annotation."
|
||||
assert isinstance(field_info.json_schema_extra, dict), (
|
||||
f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
)
|
||||
assert isinstance(
|
||||
field_info.json_schema_extra, dict
|
||||
), f"{field_name} on invocation output {output_type} has a non-dict json_schema_extra, did you forget to use InputField?"
|
||||
|
||||
cls._original_model_fields[field_name] = OriginalModelField(annotation=annotation, field_info=field_info)
|
||||
|
||||
|
@ -184,9 +184,9 @@ class SegmentAnythingInvocation(BaseInvocation):
|
||||
# Find the largest mask.
|
||||
return [max(masks, key=lambda x: float(x.sum()))]
|
||||
elif self.mask_filter == "highest_box_score":
|
||||
assert bounding_boxes is not None, (
|
||||
"Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
||||
)
|
||||
assert (
|
||||
bounding_boxes is not None
|
||||
), "Bounding boxes must be provided to use the 'highest_box_score' mask filter."
|
||||
assert len(masks) == len(bounding_boxes)
|
||||
# Find the index of the bounding box with the highest score.
|
||||
# Note that we fallback to -1.0 if the score is None. This is mainly to satisfy the type checker. In most
|
||||
|
@ -482,9 +482,9 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
assert config.schema_version == CONFIG_SCHEMA_VERSION, (
|
||||
f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
@ -1,11 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageCollectionCounts,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@ -99,37 +98,10 @@ class ImageRecordStorageBase(ABC):
|
||||
"""Gets the most recent image for a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_counts(
|
||||
self,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageCollectionCounts:
|
||||
"""Gets counts for starred and unstarred image collections."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_images(
|
||||
self,
|
||||
collection: Literal["starred", "unstarred"],
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets images from a specific collection (starred or unstarred)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
|
@ -1,13 +1,12 @@
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, Union, cast
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
IMAGE_DTO_COLS,
|
||||
ImageCategory,
|
||||
ImageCollectionCounts,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@ -388,182 +387,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
return deserialize_image_record(dict(result))
|
||||
|
||||
def get_collection_counts(
|
||||
self,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageCollectionCounts:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Build the base query conditions (same as get_many)
|
||||
base_query = """--sql
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get starred count
|
||||
starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;"
|
||||
cursor.execute(starred_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
# Get unstarred count
|
||||
unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;"
|
||||
cursor.execute(unstarred_query, query_params)
|
||||
unstarred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return ImageCollectionCounts(starred_count=starred_count, unstarred_count=unstarred_count)
|
||||
|
||||
def get_collection_images(
|
||||
self,
|
||||
collection: Literal["starred", "unstarred"],
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Base queries
|
||||
count_query = """--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
images_query = f"""--sql
|
||||
SELECT {IMAGE_DTO_COLS}
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
# Add starred/unstarred filter
|
||||
is_starred = collection == "starred"
|
||||
query_conditions += """--sql
|
||||
AND images.starred = ?
|
||||
"""
|
||||
query_params.append(is_starred)
|
||||
|
||||
if image_origin is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.image_origin = ?
|
||||
"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if categories is not None:
|
||||
category_strings = [c.value for c in set(categories)]
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"""--sql
|
||||
AND images.image_category IN ( {placeholders} )
|
||||
"""
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += """--sql
|
||||
AND images.is_intermediate = ?
|
||||
"""
|
||||
query_params.append(is_intermediate)
|
||||
|
||||
if board_id == "none":
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id IS NULL
|
||||
"""
|
||||
elif board_id is not None:
|
||||
query_conditions += """--sql
|
||||
AND board_images.board_id = ?
|
||||
"""
|
||||
query_params.append(board_id)
|
||||
|
||||
if search_term:
|
||||
query_conditions += """--sql
|
||||
AND (
|
||||
images.metadata LIKE ?
|
||||
OR images.created_at LIKE ?
|
||||
)
|
||||
"""
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Add ordering and pagination
|
||||
query_pagination = f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
|
||||
"""
|
||||
|
||||
# Execute images query
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
images_params = query_params.copy()
|
||||
images_params.extend([limit, offset])
|
||||
|
||||
cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
images = [deserialize_image_record(dict(r)) for r in result]
|
||||
|
||||
# Execute count query
|
||||
count_query += query_conditions + ";"
|
||||
cursor.execute(count_query, query_params)
|
||||
count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
@ -625,13 +451,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Order by starred first, then by created_at
|
||||
query += (
|
||||
query_conditions
|
||||
+ f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
)
|
||||
if starred_first:
|
||||
query += (
|
||||
query_conditions
|
||||
+ f"""--sql
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
)
|
||||
else:
|
||||
query += (
|
||||
query_conditions
|
||||
+ f"""--sql
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
|
@ -1,12 +1,11 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Literal, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageCollectionCounts,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@ -149,37 +148,10 @@ class ImageServiceABC(ABC):
|
||||
"""Deletes all images on a board."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_counts(
|
||||
self,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageCollectionCounts:
|
||||
"""Gets counts for starred and unstarred image collections."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_collection_images(
|
||||
self,
|
||||
collection: Literal["starred", "unstarred"],
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets images from a specific collection (starred or unstarred)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
@ -187,5 +159,5 @@ class ImageServiceABC(ABC):
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)."""
|
||||
"""Gets ordered list of all image names."""
|
||||
pass
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from PIL.Image import Image as PILImageType
|
||||
|
||||
@ -10,7 +10,6 @@ from invokeai.app.services.image_files.image_files_common import (
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageCollectionCounts,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@ -311,73 +310,9 @@ class ImageService(ImageServiceABC):
|
||||
self.__invoker.services.logger.error("Problem getting intermediates count")
|
||||
raise e
|
||||
|
||||
def get_collection_counts(
|
||||
self,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> ImageCollectionCounts:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_collection_counts(
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting collection counts")
|
||||
raise e
|
||||
|
||||
def get_collection_images(
|
||||
self,
|
||||
collection: Literal["starred", "unstarred"],
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self.__invoker.services.image_records.get_collection_images(
|
||||
collection=collection,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
is_intermediate=is_intermediate,
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
|
||||
image_dtos = [
|
||||
image_record_to_dto(
|
||||
image_record=r,
|
||||
image_url=self.__invoker.services.urls.get_image_url(r.image_name),
|
||||
thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True),
|
||||
board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name),
|
||||
)
|
||||
for r in results.items
|
||||
]
|
||||
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
self.__invoker.services.logger.error("Problem getting collection images")
|
||||
raise e
|
||||
|
||||
def get_image_names(
|
||||
self,
|
||||
starred_first: bool = True,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
@ -387,6 +322,7 @@ class ImageService(ImageServiceABC):
|
||||
) -> list[str]:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
categories=categories,
|
||||
|
@ -379,13 +379,13 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
|
||||
bytes_ = path.read_bytes()
|
||||
workflow_from_file = WorkflowValidator.validate_json(bytes_)
|
||||
|
||||
assert workflow_from_file.id.startswith("default_"), (
|
||||
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
|
||||
)
|
||||
assert workflow_from_file.id.startswith(
|
||||
"default_"
|
||||
), f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
|
||||
|
||||
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
|
||||
f"Invalid default workflow category: {workflow_from_file.meta.category}"
|
||||
)
|
||||
assert (
|
||||
workflow_from_file.meta.category is WorkflowCategory.Default
|
||||
), f"Invalid default workflow category: {workflow_from_file.meta.category}"
|
||||
|
||||
workflows_from_file.append(workflow_from_file)
|
||||
|
||||
|
@ -381,7 +381,7 @@ class LoRALyCORISConfig(LoRAConfigBase, ModelConfigBase):
|
||||
|
||||
state_dict = mod.load_state_dict()
|
||||
for key in state_dict.keys():
|
||||
if type(key) is int:
|
||||
if isinstance(key, int):
|
||||
continue
|
||||
|
||||
if key.startswith(("lora_te_", "lora_unet_", "lora_te1_", "lora_te2_", "lora_transformer_")):
|
||||
|
@ -115,19 +115,19 @@ class ModelMerger(object):
|
||||
base_models: Set[BaseModelType] = set()
|
||||
variant = None if self._installer.app_config.precision == "float32" else "fp16"
|
||||
|
||||
assert len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference, (
|
||||
"When merging three models, only the 'add_difference' merge method is supported"
|
||||
)
|
||||
assert (
|
||||
len(model_keys) <= 2 or interp == MergeInterpolationMethod.AddDifference
|
||||
), "When merging three models, only the 'add_difference' merge method is supported"
|
||||
|
||||
for key in model_keys:
|
||||
info = store.get_model(key)
|
||||
model_names.append(info.name)
|
||||
assert isinstance(info, MainDiffusersConfig), (
|
||||
f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||
)
|
||||
assert info.variant == ModelVariantType("normal"), (
|
||||
f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
)
|
||||
assert isinstance(
|
||||
info, MainDiffusersConfig
|
||||
), f"{info.name} ({info.key}) is not a diffusers model. It must be optimized before merging"
|
||||
assert info.variant == ModelVariantType(
|
||||
"normal"
|
||||
), f"{info.name} ({info.key}) is a {info.variant} model, which cannot currently be merged"
|
||||
|
||||
# tally base models used
|
||||
base_models.add(info.base)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectListImagesBaseQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
@ -13,7 +13,7 @@ export const addBoardIdSelectedListener = (startAppListening: AppStartListening)
|
||||
|
||||
const state = getState();
|
||||
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const queryArgs = { ...selectListImagesBaseQueryArgs(state), offset: 0 };
|
||||
|
||||
// wait until the board has some images - maybe it already has some from a previous fetch
|
||||
// must use getState() to ensure we do not have stale state
|
||||
|
@ -1,29 +1,9 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { RootState } from 'app/store/store';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { uniq } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
|
||||
|
||||
// Type for image collection query arguments
|
||||
type ImageCollectionQueryArgs = {
|
||||
board_id?: string;
|
||||
categories?: ImageCategory[];
|
||||
search_term?: string;
|
||||
order_dir?: SQLiteDirection;
|
||||
is_intermediate: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Helper function to get cached image names list for selection operations
|
||||
* Returns an ordered array of image names (starred first, then unstarred)
|
||||
*/
|
||||
const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => {
|
||||
const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state);
|
||||
return queryResult.data || [];
|
||||
};
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageName: string;
|
||||
@ -50,10 +30,8 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
|
||||
// Get cached image names for selection operations
|
||||
const imageNames = getCachedImageNames(state, queryArgs);
|
||||
const queryArgs = selectListImageNamesQueryArgs(state);
|
||||
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? [];
|
||||
|
||||
// If we don't have the image names cached, we can't perform selection operations
|
||||
// This can happen if the user clicks on an image before the names are loaded
|
||||
|
@ -10,7 +10,6 @@ import {
|
||||
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
|
||||
import type { CanvasState, RefImagesState } from 'features/controlLayers/store/types';
|
||||
import type { ImageUsage } from 'features/deleteImageModal/store/types';
|
||||
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
@ -81,14 +80,8 @@ const handleDeletions = async (image_names: string[], dispatch: AppDispatch, get
|
||||
await dispatch(imagesApi.endpoints.deleteImages.initiate({ image_names }, { track: false })).unwrap();
|
||||
|
||||
if (intersection(state.gallery.selection, image_names).length > 0) {
|
||||
// Some selected images were deleted, need to select the next image
|
||||
const queryArgs = selectListImagesQueryArgs(state);
|
||||
const { data } = imagesApi.endpoints.listImages.select(queryArgs)(state);
|
||||
if (data) {
|
||||
// When we delete multiple images, we clear the selection. Then, the the next time we load images, we will
|
||||
// select the first one. This is handled below in the listener for `imagesApi.endpoints.listImages.matchFulfilled`.
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
// Some selected images were deleted, clear selection
|
||||
dispatch(imageSelected(null));
|
||||
}
|
||||
|
||||
// We need to reset the features where the image is in use - none of these work if their image(s) don't exist
|
||||
|
@ -5,6 +5,7 @@ import { CanvasAlertsInvocationProgress } from 'features/controlLayers/component
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
|
||||
import NextPrevImageButtons from 'features/gallery/components/NextPrevImageButtons';
|
||||
import { selectAutoSwitch } from 'features/gallery/store/gallerySelectors';
|
||||
import type { ProgressImage as ProgressImageType } from 'features/nodes/types/common';
|
||||
import { selectShouldShowImageDetails, selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors';
|
||||
import type { AnimationProps } from 'framer-motion';
|
||||
@ -21,6 +22,7 @@ import { ProgressIndicator } from './ProgressIndicator2';
|
||||
export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | null }) => {
|
||||
const shouldShowImageDetails = useAppSelector(selectShouldShowImageDetails);
|
||||
const shouldShowProgressInViewer = useAppSelector(selectShouldShowProgressInViewer);
|
||||
const autoSwitch = useAppSelector(selectAutoSwitch);
|
||||
|
||||
const socket = useStore($socket);
|
||||
const [progressEvent, setProgressEvent] = useState<S['InvocationProgressEvent'] | null>(null);
|
||||
@ -58,6 +60,29 @@ export const CurrentImagePreview = memo(({ imageDTO }: { imageDTO: ImageDTO | nu
|
||||
};
|
||||
}, [socket]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!socket) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (autoSwitch) {
|
||||
return;
|
||||
}
|
||||
// When auto-switch is enabled, we will get a load event as we switch to the new image. This in turn clears the progress image,
|
||||
// creating the illusion of the progress image turning into the new image.
|
||||
// But when auto-switch is disabled, we won't get that load event, so we need to clear the progress image manually.
|
||||
const onQueueItemStatusChanged = () => {
|
||||
setProgressEvent(null);
|
||||
setProgressImage(null);
|
||||
};
|
||||
|
||||
socket.on('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
|
||||
return () => {
|
||||
socket.off('queue_item_status_changed', onQueueItemStatusChanged);
|
||||
};
|
||||
}, [autoSwitch, socket]);
|
||||
|
||||
const onLoadImage = useCallback(() => {
|
||||
if (!progressEvent || !imageDTO) {
|
||||
return;
|
||||
|
@ -4,10 +4,11 @@ import { logger } from 'app/logging/logger';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import {
|
||||
LIMIT,
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectImageToCompare,
|
||||
selectLastSelectedImage,
|
||||
selectListImagesQueryArgs,
|
||||
selectListImageNamesQueryArgs,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
@ -37,31 +38,26 @@ const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096;
|
||||
const DEBOUNCE_DELAY = 500;
|
||||
const SPINNER_OPACITY = 0.3;
|
||||
|
||||
type ListImagesQueryArgs = ReturnType<typeof selectListImagesQueryArgs>;
|
||||
type ListImageNamesQueryArgs = ReturnType<typeof selectListImageNamesQueryArgs>;
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ListImagesQueryArgs;
|
||||
queryArgs: ListImageNamesQueryArgs;
|
||||
imageNames: string[];
|
||||
};
|
||||
|
||||
export const useDebouncedListImagesQueryArgs = () => {
|
||||
const _galleryQueryArgs = useAppSelector(selectListImagesQueryArgs);
|
||||
const [queryArgs] = useDebounce(_galleryQueryArgs, DEBOUNCE_DELAY);
|
||||
return queryArgs;
|
||||
};
|
||||
|
||||
// Hook to get an image DTO from cache or trigger loading
|
||||
const useImageDTOFromListQuery = (
|
||||
index: number,
|
||||
imageName: string,
|
||||
queryArgs: ListImagesQueryArgs
|
||||
queryArgs: ListImageNamesQueryArgs
|
||||
): ImageDTO | null => {
|
||||
const { arg, options } = useMemo(() => {
|
||||
const pageOffset = Math.floor(index / queryArgs.limit) * queryArgs.limit;
|
||||
const pageOffset = Math.floor(index / LIMIT) * LIMIT;
|
||||
return {
|
||||
arg: {
|
||||
...queryArgs,
|
||||
offset: pageOffset,
|
||||
limit: LIMIT,
|
||||
} satisfies Parameters<typeof useListImagesQuery>[0],
|
||||
options: {
|
||||
selectFromResult: ({ data }) => {
|
||||
@ -82,7 +78,7 @@ const useImageDTOFromListQuery = (
|
||||
|
||||
// Individual image component that gets its data from RTK Query cache
|
||||
const ImageAtPosition = memo(
|
||||
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImagesQueryArgs }) => {
|
||||
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => {
|
||||
const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs);
|
||||
|
||||
if (!imageDTO) {
|
||||
@ -408,7 +404,8 @@ const getImageNamesQueryOptions = {
|
||||
} satisfies Parameters<typeof useGetImageNamesQuery>[1];
|
||||
|
||||
export const useGalleryImageNames = () => {
|
||||
const queryArgs = useDebouncedListImagesQueryArgs();
|
||||
const _queryArgs = useAppSelector(selectListImageNamesQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY);
|
||||
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions);
|
||||
return { imageNames, isLoading, isFetching, queryArgs };
|
||||
};
|
||||
|
@ -2,8 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
|
||||
import type { ListBoardsArgs, ListImagesArgs } from 'services/api/types';
|
||||
import type { SetNonNullable } from 'type-fest';
|
||||
import type { ListBoardsArgs } from 'services/api/types';
|
||||
|
||||
export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0));
|
||||
export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1));
|
||||
@ -28,7 +27,7 @@ export const selectGallerySearchTerm = createSelector(selectGallerySlice, (galle
|
||||
export const selectGalleryOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.orderDir);
|
||||
export const selectGalleryStarredFirst = createSelector(selectGallerySlice, (gallery) => gallery.starredFirst);
|
||||
|
||||
export const selectListImagesQueryArgs = createMemoizedSelector(
|
||||
export const selectListImageNamesQueryArgs = createMemoizedSelector(
|
||||
[
|
||||
selectSelectedBoardId,
|
||||
selectGalleryQueryCategories,
|
||||
@ -36,17 +35,20 @@ export const selectListImagesQueryArgs = createMemoizedSelector(
|
||||
selectGalleryOrderDir,
|
||||
selectGalleryStarredFirst,
|
||||
],
|
||||
(board_id, categories, search_term, order_dir, starred_first) =>
|
||||
({
|
||||
board_id,
|
||||
categories,
|
||||
search_term,
|
||||
order_dir,
|
||||
starred_first,
|
||||
is_intermediate: false, // We don't show intermediate images in the gallery
|
||||
limit: 100, // Page size is _always_ 100
|
||||
}) satisfies SetNonNullable<ListImagesArgs, 'limit'>
|
||||
(board_id, categories, search_term, order_dir, starred_first) => ({
|
||||
board_id,
|
||||
categories,
|
||||
search_term,
|
||||
order_dir,
|
||||
starred_first,
|
||||
is_intermediate: false,
|
||||
})
|
||||
);
|
||||
export const LIMIT = 100;
|
||||
export const selectListImagesBaseQueryArgs = createMemoizedSelector(selectListImageNamesQueryArgs, (baseQueryArgs) => ({
|
||||
...baseQueryArgs,
|
||||
limit: LIMIT,
|
||||
}));
|
||||
export const selectAutoAssignBoardOnClick = createSelector(
|
||||
selectGallerySlice,
|
||||
(gallery) => gallery.autoAssignBoardOnClick
|
||||
|
@ -427,61 +427,12 @@ export const imagesApi = api.injectEndpoints({
|
||||
},
|
||||
}),
|
||||
}),
|
||||
/**
|
||||
* Get counts for starred and unstarred image collections
|
||||
*/
|
||||
getImageCollectionCounts: build.query<
|
||||
paths['/api/v1/images/collections/counts']['get']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/images/collections/counts']['get']['parameters']['query']
|
||||
>({
|
||||
query: (queryArgs) => ({
|
||||
url: buildImagesUrl('collections/counts', queryArgs),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'],
|
||||
}),
|
||||
/**
|
||||
* Get images from a specific collection (starred or unstarred)
|
||||
*/
|
||||
getImageCollection: build.query<
|
||||
paths['/api/v1/images/collections/{collection}']['get']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/images/collections/{collection}']['get']['parameters']['path'] &
|
||||
paths['/api/v1/images/collections/{collection}']['get']['parameters']['query']
|
||||
>({
|
||||
query: ({ collection, ...queryArgs }) => ({
|
||||
url: buildImagesUrl(`collections/${collection}`, queryArgs),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result, error, { collection, board_id, categories }) => {
|
||||
const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`;
|
||||
return [
|
||||
{ type: 'ImageCollection', id: collection },
|
||||
{ type: 'ImageCollection', id: cacheKey },
|
||||
'FetchOnReconnect',
|
||||
];
|
||||
},
|
||||
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
// Populate the getImageDTO cache with these images, similar to listImages
|
||||
const res = await queryFulfilled;
|
||||
const imageDTOs = res.data.items;
|
||||
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
for (const imageDTO of imageDTOs) {
|
||||
updates.push({
|
||||
endpointName: 'getImageDTO',
|
||||
arg: imageDTO.image_name,
|
||||
value: imageDTO,
|
||||
});
|
||||
}
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
},
|
||||
}),
|
||||
/**
|
||||
* Get ordered list of image names for selection operations
|
||||
*/
|
||||
getImageNames: build.query<
|
||||
string[],
|
||||
{
|
||||
image_origin?: 'internal' | 'external' | null;
|
||||
categories?: ImageCategory[] | null;
|
||||
is_intermediate?: boolean | null;
|
||||
board_id?: string | null;
|
||||
@ -493,46 +444,11 @@ export const imagesApi = api.injectEndpoints({
|
||||
url: buildImagesUrl('names', queryArgs),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: ['ImageNameList', 'FetchOnReconnect'],
|
||||
}),
|
||||
/**
|
||||
* Get paginated images with starred first (unified list)
|
||||
*/
|
||||
getUnifiedImageList: build.query<
|
||||
ListImagesResponse,
|
||||
{
|
||||
offset?: number;
|
||||
limit?: number;
|
||||
image_origin?: 'internal' | 'external' | null;
|
||||
categories?: ImageCategory[] | null;
|
||||
is_intermediate?: boolean | null;
|
||||
board_id?: string | null;
|
||||
search_term?: string | null;
|
||||
order_dir?: SQLiteDirection;
|
||||
}
|
||||
>({
|
||||
query: (queryArgs) => ({
|
||||
url: getListImagesUrl({ ...queryArgs, starred_first: true }),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result, error, { board_id, categories }) => [
|
||||
{ type: 'ImageList', id: getListImagesUrl({ board_id, categories }) },
|
||||
providesTags: (result, error, queryArgs) => [
|
||||
'ImageNameList',
|
||||
'FetchOnReconnect',
|
||||
{ type: 'ImageNameList', id: stableHash(queryArgs) },
|
||||
],
|
||||
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
// Populate the getImageDTO cache with these images
|
||||
const res = await queryFulfilled;
|
||||
const imageDTOs = res.data.items;
|
||||
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
for (const imageDTO of imageDTOs) {
|
||||
updates.push({
|
||||
endpointName: 'getImageDTO',
|
||||
arg: imageDTO.image_name,
|
||||
value: imageDTO,
|
||||
});
|
||||
}
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
@ -555,11 +471,7 @@ export const {
|
||||
useStarImagesMutation,
|
||||
useUnstarImagesMutation,
|
||||
useBulkDownloadImagesMutation,
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
useGetUnifiedImageListQuery,
|
||||
} = imagesApi;
|
||||
|
||||
/**
|
||||
|
@ -752,7 +752,7 @@ export type paths = {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/images/collections/counts": {
|
||||
"/api/v1/images/names": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
@ -760,30 +760,10 @@ export type paths = {
|
||||
cookie?: never;
|
||||
};
|
||||
/**
|
||||
* Get Image Collection Counts
|
||||
* @description Gets counts for starred and unstarred image collections
|
||||
* Get Image Names
|
||||
* @description Gets ordered list of all image names (starred first, then unstarred)
|
||||
*/
|
||||
get: operations["get_image_collection_counts"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/images/collections/{collection}": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
/**
|
||||
* Get Image Collection
|
||||
* @description Gets images from a specific collection (starred or unstarred)
|
||||
*/
|
||||
get: operations["get_image_collection"];
|
||||
get: operations["get_image_names"];
|
||||
put?: never;
|
||||
post?: never;
|
||||
delete?: never;
|
||||
@ -9844,19 +9824,6 @@ export type components = {
|
||||
*/
|
||||
type: "img_channel_offset";
|
||||
};
|
||||
/** ImageCollectionCounts */
|
||||
ImageCollectionCounts: {
|
||||
/**
|
||||
* Starred Count
|
||||
* @description The number of starred images in the collection.
|
||||
*/
|
||||
starred_count: number;
|
||||
/**
|
||||
* Unstarred Count
|
||||
* @description The number of unstarred images in the collection.
|
||||
*/
|
||||
unstarred_count: number;
|
||||
};
|
||||
/**
|
||||
* Image Collection Primitive
|
||||
* @description A collection of image primitive values
|
||||
@ -23728,17 +23695,21 @@ export interface operations {
|
||||
};
|
||||
};
|
||||
};
|
||||
get_image_collection_counts: {
|
||||
get_image_names: {
|
||||
parameters: {
|
||||
query?: {
|
||||
/** @description The origin of images to count. */
|
||||
/** @description The origin of images to list. */
|
||||
image_origin?: components["schemas"]["ResourceOrigin"] | null;
|
||||
/** @description The categories of image to include. */
|
||||
categories?: components["schemas"]["ImageCategory"][] | null;
|
||||
/** @description Whether to include intermediate images. */
|
||||
/** @description Whether to list intermediate images. */
|
||||
is_intermediate?: boolean | null;
|
||||
/** @description The board id to filter by. Use 'none' to find images without a board. */
|
||||
board_id?: string | null;
|
||||
/** @description The order of sort */
|
||||
order_dir?: components["schemas"]["SQLiteDirection"];
|
||||
/** @description Whether to sort by starred images first */
|
||||
starred_first?: boolean;
|
||||
/** @description The term to search for */
|
||||
search_term?: string | null;
|
||||
};
|
||||
@ -23754,56 +23725,7 @@ export interface operations {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["ImageCollectionCounts"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
get_image_collection: {
|
||||
parameters: {
|
||||
query?: {
|
||||
/** @description The origin of images to list. */
|
||||
image_origin?: components["schemas"]["ResourceOrigin"] | null;
|
||||
/** @description The categories of image to include. */
|
||||
categories?: components["schemas"]["ImageCategory"][] | null;
|
||||
/** @description Whether to list intermediate images. */
|
||||
is_intermediate?: boolean | null;
|
||||
/** @description The board id to filter by. Use 'none' to find images without a board. */
|
||||
board_id?: string | null;
|
||||
/** @description The offset within the collection */
|
||||
offset?: number;
|
||||
/** @description The number of images to return */
|
||||
limit?: number;
|
||||
/** @description The order of sort */
|
||||
order_dir?: components["schemas"]["SQLiteDirection"];
|
||||
/** @description The term to search for */
|
||||
search_term?: string | null;
|
||||
};
|
||||
header?: never;
|
||||
path: {
|
||||
/** @description The collection to retrieve from */
|
||||
collection: "starred" | "unstarred";
|
||||
};
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody?: never;
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"];
|
||||
"application/json": string[];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
|
@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
selectAutoSwitch,
|
||||
selectGalleryView,
|
||||
selectListImagesQueryArgs,
|
||||
selectListImagesBaseQueryArgs,
|
||||
selectSelectedBoardId,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
@ -44,7 +44,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
const boardTotalAdditions: Record<string, number> = {};
|
||||
const boardTagIdsToInvalidate: Set<string> = new Set();
|
||||
const imageListTagIdsToInvalidate: Set<string> = new Set();
|
||||
const listImagesArg = selectListImagesQueryArgs(getState());
|
||||
const listImagesArg = selectListImagesBaseQueryArgs(getState());
|
||||
|
||||
for (const imageDTO of imageDTOs) {
|
||||
if (imageDTO.is_intermediate) {
|
||||
@ -94,7 +94,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
type: 'ImageList' as const,
|
||||
id: imageListId,
|
||||
}));
|
||||
dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags]));
|
||||
dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags]));
|
||||
|
||||
const autoSwitch = selectAutoSwitch(getState());
|
||||
|
||||
|
@ -211,12 +211,12 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||
assert Path(tmp_path, "sdxl-turbo/model_index.json").exists(), (
|
||||
f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
)
|
||||
assert Path(tmp_path, "sdxl-turbo/text_encoder/config.json").exists(), (
|
||||
f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||
)
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/model_index.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/text_encoder/config.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
queue.stop()
|
||||
|
@ -48,9 +48,9 @@ def test_flux_aitoolkit_transformer_state_dict_is_in_invoke_format():
|
||||
model_keys = set(model.state_dict().keys())
|
||||
|
||||
for converted_key_prefix in converted_key_prefixes:
|
||||
assert any(model_key.startswith(converted_key_prefix) for model_key in model_keys), (
|
||||
f"'{converted_key_prefix}' did not match any model keys."
|
||||
)
|
||||
assert any(
|
||||
model_key.startswith(converted_key_prefix) for model_key in model_keys
|
||||
), f"'{converted_key_prefix}' did not match any model keys."
|
||||
|
||||
|
||||
def test_lora_model_from_flux_aitoolkit_state_dict():
|
||||
|
Reference in New Issue
Block a user