feat(nodes): make list images route use offset pagination

Because we dynamically insert images into the DB and UI's images state, `page`/`per_page` pagination makes loading the images awkward.

Using `offset`/`limit` pagination lets us query for images with an offset equal to the number of images already loaded (which match the query parameters).

The result is that we always get the correct next page of images when loading more.
This commit is contained in:
psychedelicious 2023-05-28 18:59:14 +10:00 committed by Kent Keirsey
parent 38fd2ad45d
commit f31e62afad
4 changed files with 78 additions and 71 deletions

View File

@ -8,6 +8,7 @@ from invokeai.app.models.image import (
ImageCategory,
ResourceOrigin,
)
from invokeai.app.services.image_record_storage import OffsetPaginatedResults
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
@ -221,35 +222,28 @@ async def get_image_urls(
@images_router.get(
"/",
operation_id="list_images_with_metadata",
response_model=PaginatedResults[ImageDTO],
response_model=OffsetPaginatedResults[ImageDTO],
)
async def list_images_with_metadata(
image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list"
),
include_categories: Optional[list[ImageCategory]] = Query(
categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to include"
),
exclude_categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to exclude"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageDTO]:
offset: int = Query(default=0, description="The page offset"),
limit: int = Query(default=10, description="The number of images per page"),
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a list of images"""
if include_categories is not None and exclude_categories is not None:
raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.")
image_dtos = ApiDependencies.invoker.services.images.get_many(
page,
per_page,
offset,
limit,
image_origin,
include_categories,
exclude_categories,
categories,
is_intermediate,
)

View File

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Optional, cast
from typing import Generic, Optional, TypeVar, cast
import sqlite3
import threading
from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic.generics import GenericModel
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import (
ImageCategory,
@ -15,7 +18,18 @@ from invokeai.app.services.models.image_record import (
ImageRecordChanges,
deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults
T = TypeVar("T", bound=BaseModel)
class OffsetPaginatedResults(GenericModel, Generic[T]):
"""Offset-paginated results"""
# fmt: off
items: list[T] = Field(description="Items")
offset: int = Field(description="Offset from which to retrieve items")
limit: int = Field(description="Limit of items to get")
total: int = Field(description="Total number of items in result")
# fmt: on
# TODO: Should these excpetions subclass existing python exceptions?
@ -63,13 +77,12 @@ class ImageRecordStorageBase(ABC):
@abstractmethod
def get_many(
self,
page: int = 0,
per_page: int = 10,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
) -> OffsetPaginatedResults[ImageRecord]:
"""Gets a page of image records."""
pass
@ -238,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""",
(changes.session_id, image_name),
)
# Change the image's `is_intermediate`` flag
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET is_intermediate = ?
WHERE image_name = ?;
""",
(changes.is_intermediate, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
@ -247,13 +271,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def get_many(
self,
page: int = 0,
per_page: int = 10,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]:
) -> OffsetPaginatedResults[ImageRecord]:
try:
self._lock.acquire()
@ -269,30 +292,18 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_conditions += f"""AND image_origin = ?\n"""
query_params.append(image_origin.value)
if include_categories is not None:
if categories is not None:
## Convert the enum values to unique list of strings
include_category_strings = list(
map(lambda c: c.value, set(include_categories))
category_strings = list(
map(lambda c: c.value, set(categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(include_category_strings))
placeholders = ",".join("?" * len(category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*include_category_strings)
if exclude_categories is not None:
## Convert the enum values to unique list of strings
exclude_category_strings = list(
map(lambda c: c.value, set(exclude_categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(exclude_category_strings))
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*exclude_category_strings)
for c in category_strings:
query_params.append(c)
if is_intermediate is not None:
query_conditions += f"""AND is_intermediate = ?\n"""
@ -304,8 +315,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy()
images_params.append(per_page)
images_params.append(page * per_page)
images_params.append(limit)
images_params.append(offset)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall())
@ -322,10 +333,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
finally:
self._lock.release()
pageCount = int(count / per_page) + 1
return PaginatedResults(
items=images, page=page, pages=pageCount, per_page=per_page, total=count
return OffsetPaginatedResults(
items=images, offset=offset, limit=limit, total=count
)
def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:

View File

@ -15,6 +15,7 @@ from invokeai.app.services.image_record_storage import (
ImageRecordNotFoundException,
ImageRecordSaveException,
ImageRecordStorageBase,
OffsetPaginatedResults,
)
from invokeai.app.services.models.image_record import (
ImageRecord,
@ -98,13 +99,12 @@ class ImageServiceABC(ABC):
@abstractmethod
def get_many(
self,
page: int = 0,
per_page: int = 10,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
) -> OffsetPaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs."""
pass
@ -328,20 +328,18 @@ class ImageService(ImageServiceABC):
def get_many(
self,
page: int = 0,
per_page: int = 10,
offset: int = 0,
limit: int = 10,
image_origin: Optional[ResourceOrigin] = None,
include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]:
) -> OffsetPaginatedResults[ImageDTO]:
try:
results = self._services.records.get_many(
page,
per_page,
offset,
limit,
image_origin,
include_categories,
exclude_categories,
categories,
is_intermediate,
)
@ -358,11 +356,10 @@ class ImageService(ImageServiceABC):
)
)
return PaginatedResults[ImageDTO](
return OffsetPaginatedResults[ImageDTO](
items=image_dtos,
page=results.page,
pages=results.pages,
per_page=results.per_page,
offset=results.offset,
limit=results.limit,
total=results.total,
)
except Exception as e:

View File

@ -1,6 +1,6 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictStr
from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
- `is_intermediate`: change the image's `is_intermediate` flag
"""
image_category: Optional[ImageCategory] = Field(
@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid):
description="The image's new session ID.",
)
"""The image's new session ID."""
is_intermediate: Optional[StrictBool] = Field(
default=None, description="The image's new `is_intermediate` flag."
)
"""The image's new `is_intermediate` flag."""
class ImageUrlsDTO(BaseModel):
@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown")
image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
image_origin = ResourceOrigin(
image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)
)
image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value)
)