diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 11d5eaef97..013c60ed06 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -7,7 +7,7 @@ from fastapi import Path, Query, Request, UploadFile from fastapi.responses import FileResponse, Response from fastapi.routing import APIRouter from PIL import Image -from invokeai.app.invocations.image import ImageField +from invokeai.app.datatypes.image import ImageResponse from invokeai.app.services.item_storage import PaginatedResults from ...services.image_storage import ImageType @@ -70,13 +70,13 @@ async def upload_image(file: UploadFile, request: Request): @images_router.get( "/", operation_id="list_images", - responses={200: {"model": PaginatedResults[ImageField]}}, + responses={200: {"model": PaginatedResults[ImageResponse]}}, ) async def list_images( image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"), page: int = Query(default=0, description="The page of images to get"), per_page: int = Query(default=10, description="The number of images per page"), -) -> PaginatedResults[ImageField]: +) -> PaginatedResults[ImageResponse]: """Gets a list of images""" result = ApiDependencies.invoker.services.images.list( image_type, page, per_page diff --git a/invokeai/app/datatypes/image.py b/invokeai/app/datatypes/image.py index 9edb16800d..fb1b464b19 100644 --- a/invokeai/app/datatypes/image.py +++ b/invokeai/app/datatypes/image.py @@ -2,6 +2,8 @@ from enum import Enum from typing import Optional from pydantic import BaseModel, Field +from invokeai.app.datatypes.metadata import ImageMetadata + class ImageType(str, Enum): RESULT = "results" @@ -24,3 +26,13 @@ class ImageField(BaseModel): "image_name", ] } + + +class ImageResponse(BaseModel): + """The response type for images""" + + image_type: ImageType = Field(description="The type of the image") + image_name: str = Field(description="The name of the image") + image_url: str = Field(description="The url of the image") + thumbnail_url: str = Field(description="The url of the image's thumbnail") + metadata: ImageMetadata = Field(description="The image's metadata") diff --git a/invokeai/app/datatypes/metadata.py b/invokeai/app/datatypes/metadata.py new file mode 100644 index 0000000000..dc2f1d4dda --- /dev/null +++ b/invokeai/app/datatypes/metadata.py @@ -0,0 +1,11 @@ +from typing import Optional +from pydantic import BaseModel, Field + +class ImageMetadata(BaseModel): + """An image's metadata""" + + timestamp: int = Field(description="The creation timestamp of the image") + width: int = Field(description="The width of the image in pixels") + height: int = Field(description="The width of the image in pixels") + # TODO: figure out metadata + sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata") diff --git a/invokeai/app/services/image_storage.py b/invokeai/app/services/image_storage.py index e9edd85d6e..8912b993bc 100644 --- a/invokeai/app/services/image_storage.py +++ b/invokeai/app/services/image_storage.py @@ -7,15 +7,19 @@ from abc import ABC, abstractmethod from enum import Enum from pathlib import Path from queue import Queue -from typing import Callable, Dict +from typing import Callable, Dict, List from PIL.Image import Image -from invokeai.app.datatypes.image import ImageField, ImageType +import PIL.Image as PILImage +from pydantic import BaseModel +from invokeai.app.datatypes.image import ImageField, ImageResponse, ImageType +from invokeai.app.datatypes.metadata import ImageMetadata from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.util.save_thumbnail import save_thumbnail from invokeai.backend.image_util import PngWriter + class ImageStorageBase(ABC): """Responsible for storing and retrieving images.""" @@ -26,12 +30,14 @@ class ImageStorageBase(ABC): @abstractmethod def list( self, image_type: ImageType, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[ImageField]: + ) -> PaginatedResults[ImageResponse]: pass # TODO: make this a bit more flexible for e.g. cloud storage @abstractmethod - def get_path(self, image_type: ImageType, image_name: str) -> str: + def get_path( + self, image_type: ImageType, image_name: str, is_thumbnail: bool = False + ) -> str: pass @abstractmethod @@ -75,32 +81,46 @@ class DiskImageStorage(ImageStorageBase): def list( self, image_type: ImageType, page: int = 0, per_page: int = 10 - ) -> PaginatedResults[ImageField]: + ) -> PaginatedResults[ImageResponse]: dir_path = os.path.join(self.__output_folder, image_type) image_paths = glob(f"{dir_path}/*.png") + count = len(image_paths) - # just want the filenames - image_filenames = list(map(lambda i: os.path.basename(i), image_paths)) - - # we want to sort the images by timestamp, but we don't trust the filesystem - # we do have a timestamp in the filename: `{uuid}_{timestamp}.png` - sorted_paths = sorted( - # extract the timestamp as int and multiply -1 to reverse sorting - image_filenames, key=lambda i: int(os.path.splitext(i)[0].split("_")[-1]) * -1 - ) - - all_images = list( - map(lambda i: ImageField(image_type=image_type, image_name=i), sorted_paths) + # TODO: do all platforms support `getmtime`? seem to recall some do not... + sorted_image_paths = sorted( + glob(f"{dir_path}/*.png"), key=os.path.getmtime, reverse=True ) - count = len(all_images) - page_of_images = all_images[page * per_page : (page + 1) * per_page] + page_of_image_paths = sorted_image_paths[ + page * per_page : (page + 1) * per_page + ] - page_count_trunc = int(count / per_page) + page_of_images: List[ImageResponse] = [] + + for path in page_of_image_paths: + filename = os.path.basename(path) + img = PILImage.open(path) + page_of_images.append( + ImageResponse( + image_type=image_type.value, + image_name=os.path.basename(path), + # TODO: DiskImageStorage should not be building URLs...? + image_url=f"api/v1/images/{image_type.value}/{filename}", + thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp", + # TODO: Creation of this object should happen elsewhere, just making it fit here so it works + metadata=ImageMetadata( + timestamp=int(os.path.splitext(filename)[0].split("_")[-1]), + width=img.width, + height=img.height, + ), + ) + ) + + page_count_trunc = int(count / per_page) page_count_mod = count % per_page page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1 - return PaginatedResults[ImageField]( + return PaginatedResults[ImageResponse]( items=page_of_images, page=page, pages=page_count, @@ -119,8 +139,15 @@ class DiskImageStorage(ImageStorageBase): return image # TODO: make this a bit more flexible for e.g. cloud storage - def get_path(self, image_type: ImageType, image_name: str) -> str: - path = os.path.join(self.__output_folder, image_type, image_name) + def get_path( + self, image_type: ImageType, image_name: str, is_thumbnail: bool = False + ) -> str: + if is_thumbnail: + path = os.path.join( + self.__output_folder, image_type, "thumbnails", image_name + ) + else: + path = os.path.join(self.__output_folder, image_type, image_name) return path def save(self, image_type: ImageType, image_name: str, image: Image) -> None: @@ -138,12 +165,19 @@ class DiskImageStorage(ImageStorageBase): def delete(self, image_type: ImageType, image_name: str) -> None: image_path = self.get_path(image_type, image_name) + thumbnail_path = self.get_path(image_type, image_name, True) if os.path.exists(image_path): os.remove(image_path) if image_path in self.__cache: del self.__cache[image_path] + if os.path.exists(thumbnail_path): + os.remove(thumbnail_path) + + if thumbnail_path in self.__cache: + del self.__cache[thumbnail_path] + def __get_cache(self, image_name: str) -> Image: return None if image_name not in self.__cache else self.__cache[image_name]