feat(nodes): wip ImageResponse

This commit is contained in:
psychedelicious 2023-04-05 14:00:43 +10:00
parent cc3401a159
commit a35dc090c5
4 changed files with 83 additions and 26 deletions

View File

@ -7,7 +7,7 @@ from fastapi import Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
from fastapi.routing import APIRouter
from PIL import Image
from invokeai.app.invocations.image import ImageField
from invokeai.app.datatypes.image import ImageResponse
from invokeai.app.services.item_storage import PaginatedResults
from ...services.image_storage import ImageType
@ -70,13 +70,13 @@ async def upload_image(file: UploadFile, request: Request):
@images_router.get(
"/",
operation_id="list_images",
responses={200: {"model": PaginatedResults[ImageField]}},
responses={200: {"model": PaginatedResults[ImageResponse]}},
)
async def list_images(
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageField]:
) -> PaginatedResults[ImageResponse]:
"""Gets a list of images"""
result = ApiDependencies.invoker.services.images.list(
image_type, page, per_page

View File

@ -2,6 +2,8 @@ from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.datatypes.metadata import ImageMetadata
class ImageType(str, Enum):
RESULT = "results"
@ -24,3 +26,13 @@ class ImageField(BaseModel):
"image_name",
]
}
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageMetadata = Field(description="The image's metadata")

View File

@ -0,0 +1,11 @@
from typing import Optional
from pydantic import BaseModel, Field
class ImageMetadata(BaseModel):
"""An image's metadata"""
timestamp: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The width of the image in pixels")
# TODO: figure out metadata
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")

View File

@ -7,15 +7,19 @@ from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from queue import Queue
from typing import Callable, Dict
from typing import Callable, Dict, List
from PIL.Image import Image
from invokeai.app.datatypes.image import ImageField, ImageType
import PIL.Image as PILImage
from pydantic import BaseModel
from invokeai.app.datatypes.image import ImageField, ImageResponse, ImageType
from invokeai.app.datatypes.metadata import ImageMetadata
from invokeai.app.services.item_storage import PaginatedResults
from invokeai.app.util.save_thumbnail import save_thumbnail
from invokeai.backend.image_util import PngWriter
class ImageStorageBase(ABC):
"""Responsible for storing and retrieving images."""
@ -26,12 +30,14 @@ class ImageStorageBase(ABC):
@abstractmethod
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageField]:
) -> PaginatedResults[ImageResponse]:
pass
# TODO: make this a bit more flexible for e.g. cloud storage
@abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str:
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
pass
@abstractmethod
@ -75,32 +81,46 @@ class DiskImageStorage(ImageStorageBase):
def list(
self, image_type: ImageType, page: int = 0, per_page: int = 10
) -> PaginatedResults[ImageField]:
) -> PaginatedResults[ImageResponse]:
dir_path = os.path.join(self.__output_folder, image_type)
image_paths = glob(f"{dir_path}/*.png")
count = len(image_paths)
# just want the filenames
image_filenames = list(map(lambda i: os.path.basename(i), image_paths))
# we want to sort the images by timestamp, but we don't trust the filesystem
# we do have a timestamp in the filename: `{uuid}_{timestamp}.png`
sorted_paths = sorted(
# extract the timestamp as int and multiply -1 to reverse sorting
image_filenames, key=lambda i: int(os.path.splitext(i)[0].split("_")[-1]) * -1
)
all_images = list(
map(lambda i: ImageField(image_type=image_type, image_name=i), sorted_paths)
# TODO: do all platforms support `getmtime`? seem to recall some do not...
sorted_image_paths = sorted(
glob(f"{dir_path}/*.png"), key=os.path.getmtime, reverse=True
)
count = len(all_images)
page_of_images = all_images[page * per_page : (page + 1) * per_page]
page_of_image_paths = sorted_image_paths[
page * per_page : (page + 1) * per_page
]
page_count_trunc = int(count / per_page)
page_of_images: List[ImageResponse] = []
for path in page_of_image_paths:
filename = os.path.basename(path)
img = PILImage.open(path)
page_of_images.append(
ImageResponse(
image_type=image_type.value,
image_name=os.path.basename(path),
# TODO: DiskImageStorage should not be building URLs...?
image_url=f"api/v1/images/{image_type.value}/{filename}",
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
metadata=ImageMetadata(
timestamp=int(os.path.splitext(filename)[0].split("_")[-1]),
width=img.width,
height=img.height,
),
)
)
page_count_trunc = int(count / per_page)
page_count_mod = count % per_page
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
return PaginatedResults[ImageField](
return PaginatedResults[ImageResponse](
items=page_of_images,
page=page,
pages=page_count,
@ -119,8 +139,15 @@ class DiskImageStorage(ImageStorageBase):
return image
# TODO: make this a bit more flexible for e.g. cloud storage
def get_path(self, image_type: ImageType, image_name: str) -> str:
path = os.path.join(self.__output_folder, image_type, image_name)
def get_path(
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
) -> str:
if is_thumbnail:
path = os.path.join(
self.__output_folder, image_type, "thumbnails", image_name
)
else:
path = os.path.join(self.__output_folder, image_type, image_name)
return path
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
@ -138,12 +165,19 @@ class DiskImageStorage(ImageStorageBase):
def delete(self, image_type: ImageType, image_name: str) -> None:
image_path = self.get_path(image_type, image_name)
thumbnail_path = self.get_path(image_type, image_name, True)
if os.path.exists(image_path):
os.remove(image_path)
if image_path in self.__cache:
del self.__cache[image_path]
if os.path.exists(thumbnail_path):
os.remove(thumbnail_path)
if thumbnail_path in self.__cache:
del self.__cache[thumbnail_path]
def __get_cache(self, image_name: str) -> Image:
return None if image_name not in self.__cache else self.__cache[image_name]