diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index f0399a2d07..ae10cce140 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -8,6 +8,7 @@ from invokeai.app.models.image import ( ImageCategory, ResourceOrigin, ) +from invokeai.app.services.image_record_storage import OffsetPaginatedResults from invokeai.app.services.models.image_record import ( ImageDTO, ImageRecordChanges, @@ -221,35 +222,28 @@ async def get_image_urls( @images_router.get( "/", operation_id="list_images_with_metadata", - response_model=PaginatedResults[ImageDTO], + response_model=OffsetPaginatedResults[ImageDTO], ) async def list_images_with_metadata( image_origin: Optional[ResourceOrigin] = Query( default=None, description="The origin of images to list" ), - include_categories: Optional[list[ImageCategory]] = Query( + categories: Optional[list[ImageCategory]] = Query( default=None, description="The categories of image to include" ), - exclude_categories: Optional[list[ImageCategory]] = Query( - default=None, description="The categories of image to exclude" - ), is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), - page: int = Query(default=0, description="The page of images to get"), - per_page: int = Query(default=10, description="The number of images per page"), -) -> PaginatedResults[ImageDTO]: + offset: int = Query(default=0, description="The page offset"), + limit: int = Query(default=10, description="The number of images per page"), +) -> OffsetPaginatedResults[ImageDTO]: """Gets a list of images""" - if include_categories is not None and exclude_categories is not None: - raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.") - image_dtos = ApiDependencies.invoker.services.images.get_many( - page, - per_page, + offset, + limit, image_origin, - include_categories, - exclude_categories, + categories, is_intermediate, ) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 6b6d1ce7b2..c27596afac 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -1,10 +1,13 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Optional, cast +from typing import Generic, Optional, TypeVar, cast import sqlite3 import threading from typing import Optional, Union +from pydantic import BaseModel, Field +from pydantic.generics import GenericModel + from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, @@ -15,7 +18,18 @@ from invokeai.app.services.models.image_record import ( ImageRecordChanges, deserialize_image_record, ) -from invokeai.app.services.item_storage import PaginatedResults + +T = TypeVar("T", bound=BaseModel) + +class OffsetPaginatedResults(GenericModel, Generic[T]): + """Offset-paginated results""" + + # fmt: off + items: list[T] = Field(description="Items") + offset: int = Field(description="Offset from which to retrieve items") + limit: int = Field(description="Limit of items to get") + total: int = Field(description="Total number of items in result") + # fmt: on # TODO: Should these excpetions subclass existing python exceptions? @@ -63,13 +77,12 @@ class ImageRecordStorageBase(ABC): @abstractmethod def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageRecord]: + ) -> OffsetPaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -238,6 +251,17 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """, (changes.session_id, image_name), ) + + # Change the image's `is_intermediate`` flag + if changes.session_id is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET is_intermediate = ? + WHERE image_name = ?; + """, + (changes.is_intermediate, image_name), + ) self._conn.commit() except sqlite3.Error as e: self._conn.rollback() @@ -247,13 +271,12 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageRecord]: + ) -> OffsetPaginatedResults[ImageRecord]: try: self._lock.acquire() @@ -269,30 +292,18 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_conditions += f"""AND image_origin = ?\n""" query_params.append(image_origin.value) - if include_categories is not None: + if categories is not None: ## Convert the enum values to unique list of strings - include_category_strings = list( - map(lambda c: c.value, set(include_categories)) + category_strings = list( + map(lambda c: c.value, set(categories)) ) # Create the correct length of placeholders - placeholders = ",".join("?" * len(include_category_strings)) + placeholders = ",".join("?" * len(category_strings)) query_conditions += f"AND image_category IN ( {placeholders} )\n" # Unpack the included categories into the query params - query_params.append(*include_category_strings) - - if exclude_categories is not None: - ## Convert the enum values to unique list of strings - exclude_category_strings = list( - map(lambda c: c.value, set(exclude_categories)) - ) - - # Create the correct length of placeholders - placeholders = ",".join("?" * len(exclude_category_strings)) - query_conditions += f"AND image_category NOT IN ( {placeholders} )\n" - - # Unpack the included categories into the query params - query_params.append(*exclude_category_strings) + for c in category_strings: + query_params.append(c) if is_intermediate is not None: query_conditions += f"""AND is_intermediate = ?\n""" @@ -304,8 +315,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): images_query += query_conditions + query_pagination + ";" # Add all the parameters images_params = query_params.copy() - images_params.append(per_page) - images_params.append(page * per_page) + images_params.append(limit) + images_params.append(offset) # Build the list of images, deserializing each row self._cursor.execute(images_query, images_params) result = cast(list[sqlite3.Row], self._cursor.fetchall()) @@ -322,10 +333,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): finally: self._lock.release() - pageCount = int(count / per_page) + 1 - - return PaginatedResults( - items=images, page=page, pages=pageCount, per_page=per_page, total=count + return OffsetPaginatedResults( + items=images, offset=offset, limit=limit, total=count ) def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index dca95f673f..2618a9763e 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -15,6 +15,7 @@ from invokeai.app.services.image_record_storage import ( ImageRecordNotFoundException, ImageRecordSaveException, ImageRecordStorageBase, + OffsetPaginatedResults, ) from invokeai.app.services.models.image_record import ( ImageRecord, @@ -98,13 +99,12 @@ class ImageServiceABC(ABC): @abstractmethod def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageDTO]: + ) -> OffsetPaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @@ -328,20 +328,18 @@ class ImageService(ImageServiceABC): def get_many( self, - page: int = 0, - per_page: int = 10, + offset: int = 0, + limit: int = 10, image_origin: Optional[ResourceOrigin] = None, - include_categories: Optional[list[ImageCategory]] = None, - exclude_categories: Optional[list[ImageCategory]] = None, + categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - ) -> PaginatedResults[ImageDTO]: + ) -> OffsetPaginatedResults[ImageDTO]: try: results = self._services.records.get_many( - page, - per_page, + offset, + limit, image_origin, - include_categories, - exclude_categories, + categories, is_intermediate, ) @@ -358,11 +356,10 @@ class ImageService(ImageServiceABC): ) ) - return PaginatedResults[ImageDTO]( + return OffsetPaginatedResults[ImageDTO]( items=image_dtos, - page=results.page, - pages=results.pages, - per_page=results.per_page, + offset=results.offset, + limit=results.limit, total=results.total, ) except Exception as e: diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index f143a30928..051236b12b 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,6 +1,6 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Extra, Field, StrictStr +from pydantic import BaseModel, Extra, Field, StrictBool, StrictStr from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -56,6 +56,7 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): Only limited changes are valid: - `image_category`: change the category of an image - `session_id`: change the session associated with an image + - `is_intermediate`: change the image's `is_intermediate` flag """ image_category: Optional[ImageCategory] = Field( @@ -67,6 +68,10 @@ class ImageRecordChanges(BaseModel, extra=Extra.forbid): description="The image's new session ID.", ) """The image's new session ID.""" + is_intermediate: Optional[StrictBool] = Field( + default=None, description="The image's new `is_intermediate` flag." + ) + """The image's new `is_intermediate` flag.""" class ImageUrlsDTO(BaseModel): @@ -105,7 +110,9 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. image_name = image_dict.get("image_name", "unknown") - image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)) + image_origin = ResourceOrigin( + image_dict.get("image_origin", ResourceOrigin.INTERNAL.value) + ) image_category = ImageCategory( image_dict.get("image_category", ImageCategory.GENERAL.value) )