mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): make list images route use offset pagination
Because we dynamically insert images into the DB and UI's images state, `page`/`per_page` pagination makes loading the images awkward. Using `offset`/`limit` pagination lets us query for images with an offset equal to the number of images already loaded (which match the query parameters). The result is that we always get the correct next page of images when loading more.
This commit is contained in:
parent
38fd2ad45d
commit
f31e62afad
@ -8,6 +8,7 @@ from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageDTO,
|
||||
ImageRecordChanges,
|
||||
@ -221,35 +222,28 @@ async def get_image_urls(
|
||||
@images_router.get(
|
||||
"/",
|
||||
operation_id="list_images_with_metadata",
|
||||
response_model=PaginatedResults[ImageDTO],
|
||||
response_model=OffsetPaginatedResults[ImageDTO],
|
||||
)
|
||||
async def list_images_with_metadata(
|
||||
image_origin: Optional[ResourceOrigin] = Query(
|
||||
default=None, description="The origin of images to list"
|
||||
),
|
||||
include_categories: Optional[list[ImageCategory]] = Query(
|
||||
categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to include"
|
||||
),
|
||||
exclude_categories: Optional[list[ImageCategory]] = Query(
|
||||
default=None, description="The categories of image to exclude"
|
||||
),
|
||||
is_intermediate: Optional[bool] = Query(
|
||||
default=None, description="Whether to list intermediate images"
|
||||
),
|
||||
page: int = Query(default=0, description="The page of images to get"),
|
||||
per_page: int = Query(default=10, description="The number of images per page"),
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
offset: int = Query(default=0, description="The page offset"),
|
||||
limit: int = Query(default=10, description="The number of images per page"),
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a list of images"""
|
||||
|
||||
if include_categories is not None and exclude_categories is not None:
|
||||
raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.")
|
||||
|
||||
image_dtos = ApiDependencies.invoker.services.images.get_many(
|
||||
page,
|
||||
per_page,
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
include_categories,
|
||||
exclude_categories,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
|
@ -1,10 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
import sqlite3
|
||||
import threading
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.models.image import (
|
||||
ImageCategory,
|
||||
@ -15,7 +18,18 @@ from invokeai.app.services.models.image_record import (
|
||||
ImageRecordChanges,
|
||||
deserialize_image_record,
|
||||
)
|
||||
from invokeai.app.services.item_storage import PaginatedResults
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
class OffsetPaginatedResults(GenericModel, Generic[T]):
|
||||
"""Offset-paginated results"""
|
||||
|
||||
# fmt: off
|
||||
items: list[T] = Field(description="Items")
|
||||
offset: int = Field(description="Offset from which to retrieve items")
|
||||
limit: int = Field(description="Limit of items to get")
|
||||
total: int = Field(description="Total number of items in result")
|
||||
# fmt: on
|
||||
|
||||
|
||||
# TODO: Should these excpetions subclass existing python exceptions?
|
||||
@ -63,13 +77,12 @@ class ImageRecordStorageBase(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
"""Gets a page of image records."""
|
||||
pass
|
||||
|
||||
@ -238,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
""",
|
||||
(changes.session_id, image_name),
|
||||
)
|
||||
|
||||
# Change the image's `is_intermediate`` flag
|
||||
if changes.session_id is not None:
|
||||
self._cursor.execute(
|
||||
f"""--sql
|
||||
UPDATE images
|
||||
SET is_intermediate = ?
|
||||
WHERE image_name = ?;
|
||||
""",
|
||||
(changes.is_intermediate, image_name),
|
||||
)
|
||||
self._conn.commit()
|
||||
except sqlite3.Error as e:
|
||||
self._conn.rollback()
|
||||
@ -247,13 +271,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageRecord]:
|
||||
) -> OffsetPaginatedResults[ImageRecord]:
|
||||
try:
|
||||
self._lock.acquire()
|
||||
|
||||
@ -269,30 +292,18 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_conditions += f"""AND image_origin = ?\n"""
|
||||
query_params.append(image_origin.value)
|
||||
|
||||
if include_categories is not None:
|
||||
if categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
include_category_strings = list(
|
||||
map(lambda c: c.value, set(include_categories))
|
||||
category_strings = list(
|
||||
map(lambda c: c.value, set(categories))
|
||||
)
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(include_category_strings))
|
||||
placeholders = ",".join("?" * len(category_strings))
|
||||
query_conditions += f"AND image_category IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
query_params.append(*include_category_strings)
|
||||
|
||||
if exclude_categories is not None:
|
||||
## Convert the enum values to unique list of strings
|
||||
exclude_category_strings = list(
|
||||
map(lambda c: c.value, set(exclude_categories))
|
||||
)
|
||||
|
||||
# Create the correct length of placeholders
|
||||
placeholders = ",".join("?" * len(exclude_category_strings))
|
||||
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
|
||||
|
||||
# Unpack the included categories into the query params
|
||||
query_params.append(*exclude_category_strings)
|
||||
for c in category_strings:
|
||||
query_params.append(c)
|
||||
|
||||
if is_intermediate is not None:
|
||||
query_conditions += f"""AND is_intermediate = ?\n"""
|
||||
@ -304,8 +315,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
images_query += query_conditions + query_pagination + ";"
|
||||
# Add all the parameters
|
||||
images_params = query_params.copy()
|
||||
images_params.append(per_page)
|
||||
images_params.append(page * per_page)
|
||||
images_params.append(limit)
|
||||
images_params.append(offset)
|
||||
# Build the list of images, deserializing each row
|
||||
self._cursor.execute(images_query, images_params)
|
||||
result = cast(list[sqlite3.Row], self._cursor.fetchall())
|
||||
@ -322,10 +333,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
pageCount = int(count / per_page) + 1
|
||||
|
||||
return PaginatedResults(
|
||||
items=images, page=page, pages=pageCount, per_page=per_page, total=count
|
||||
return OffsetPaginatedResults(
|
||||
items=images, offset=offset, limit=limit, total=count
|
||||
)
|
||||
|
||||
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
|
||||
|
@ -15,6 +15,7 @@ from invokeai.app.services.image_record_storage import (
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ImageRecordStorageBase,
|
||||
OffsetPaginatedResults,
|
||||
)
|
||||
from invokeai.app.services.models.image_record import (
|
||||
ImageRecord,
|
||||
@ -98,13 +99,12 @@ class ImageServiceABC(ABC):
|
||||
@abstractmethod
|
||||
def get_many(
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
"""Gets a paginated list of image DTOs."""
|
||||
pass
|
||||
|
||||
@ -328,20 +328,18 @@ class ImageService(ImageServiceABC):
|
||||
|
||||
def get_many(
|
||||
self,
|
||||
page: int = 0,
|
||||
per_page: int = 10,
|
||||
offset: int = 0,
|
||||
limit: int = 10,
|
||||
image_origin: Optional[ResourceOrigin] = None,
|
||||
include_categories: Optional[list[ImageCategory]] = None,
|
||||
exclude_categories: Optional[list[ImageCategory]] = None,
|
||||
categories: Optional[list[ImageCategory]] = None,
|
||||
is_intermediate: Optional[bool] = None,
|
||||
) -> PaginatedResults[ImageDTO]:
|
||||
) -> OffsetPaginatedResults[ImageDTO]:
|
||||
try:
|
||||
results = self._services.records.get_many(
|
||||
page,
|
||||
per_page,
|
||||
offset,
|
||||
limit,
|
||||
image_origin,
|
||||
include_categories,
|
||||
exclude_categories,
|
||||
categories,
|
||||
is_intermediate,
|
||||
)
|
||||
|
||||
@ -358,11 +356,10 @@ class ImageService(ImageServiceABC):
|
||||
)
|
||||
)
|
||||
|
||||
return PaginatedResults[ImageDTO](
|
||||
return OffsetPaginatedResults[ImageDTO](
|
||||
items=image_dtos,
|
||||
page=results.page,
|
||||
pages=results.pages,
|
||||
per_page=results.per_page,
|
||||
offset=results.offset,
|
||||
limit=results.limit,
|
||||
total=results.total,
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
from typing import Optional, Union
|
||||
from pydantic import BaseModel, Extra, Field, StrictStr
|
||||
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
|
||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.models.metadata import ImageMetadata
|
||||
from invokeai.app.util.misc import get_iso_timestamp
|
||||
@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
Only limited changes are valid:
|
||||
- `image_category`: change the category of an image
|
||||
- `session_id`: change the session associated with an image
|
||||
- `is_intermediate`: change the image's `is_intermediate` flag
|
||||
"""
|
||||
|
||||
image_category: Optional[ImageCategory] = Field(
|
||||
@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
|
||||
description="The image's new session ID.",
|
||||
)
|
||||
"""The image's new session ID."""
|
||||
is_intermediate: Optional[StrictBool] = Field(
|
||||
default=None, description="The image's new `is_intermediate` flag."
|
||||
)
|
||||
"""The image's new `is_intermediate` flag."""
|
||||
|
||||
|
||||
class ImageUrlsDTO(BaseModel):
|
||||
@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
# Retrieve all the values, setting "reasonable" defaults if they are not present.
|
||||
|
||||
image_name = image_dict.get("image_name", "unknown")
|
||||
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
|
||||
image_origin = ResourceOrigin(
|
||||
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
|
||||
)
|
||||
image_category = ImageCategory(
|
||||
image_dict.get("image_category", ImageCategory.GENERAL.value)
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user