mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): wip ImageResponse
This commit is contained in:
parent
cc3401a159
commit
a35dc090c5
@ -7,7 +7,7 @@ from fastapi import Path, Query, Request, UploadFile
|
|||||||
from fastapi.responses import FileResponse, Response
|
from fastapi.responses import FileResponse, Response
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from invokeai.app.invocations.image import ImageField
|
from invokeai.app.datatypes.image import ImageResponse
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
|
|
||||||
from ...services.image_storage import ImageType
|
from ...services.image_storage import ImageType
|
||||||
@ -70,13 +70,13 @@ async def upload_image(file: UploadFile, request: Request):
|
|||||||
@images_router.get(
|
@images_router.get(
|
||||||
"/",
|
"/",
|
||||||
operation_id="list_images",
|
operation_id="list_images",
|
||||||
responses={200: {"model": PaginatedResults[ImageField]}},
|
responses={200: {"model": PaginatedResults[ImageResponse]}},
|
||||||
)
|
)
|
||||||
async def list_images(
|
async def list_images(
|
||||||
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
image_type: ImageType = Query(default=ImageType.RESULT, description="The type of images to get"),
|
||||||
page: int = Query(default=0, description="The page of images to get"),
|
page: int = Query(default=0, description="The page of images to get"),
|
||||||
per_page: int = Query(default=10, description="The number of images per page"),
|
per_page: int = Query(default=10, description="The number of images per page"),
|
||||||
) -> PaginatedResults[ImageField]:
|
) -> PaginatedResults[ImageResponse]:
|
||||||
"""Gets a list of images"""
|
"""Gets a list of images"""
|
||||||
result = ApiDependencies.invoker.services.images.list(
|
result = ApiDependencies.invoker.services.images.list(
|
||||||
image_type, page, per_page
|
image_type, page, per_page
|
||||||
|
@ -2,6 +2,8 @@ from enum import Enum
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from invokeai.app.datatypes.metadata import ImageMetadata
|
||||||
|
|
||||||
|
|
||||||
class ImageType(str, Enum):
|
class ImageType(str, Enum):
|
||||||
RESULT = "results"
|
RESULT = "results"
|
||||||
@ -24,3 +26,13 @@ class ImageField(BaseModel):
|
|||||||
"image_name",
|
"image_name",
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ImageResponse(BaseModel):
|
||||||
|
"""The response type for images"""
|
||||||
|
|
||||||
|
image_type: ImageType = Field(description="The type of the image")
|
||||||
|
image_name: str = Field(description="The name of the image")
|
||||||
|
image_url: str = Field(description="The url of the image")
|
||||||
|
thumbnail_url: str = Field(description="The url of the image's thumbnail")
|
||||||
|
metadata: ImageMetadata = Field(description="The image's metadata")
|
||||||
|
11
invokeai/app/datatypes/metadata.py
Normal file
11
invokeai/app/datatypes/metadata.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
class ImageMetadata(BaseModel):
|
||||||
|
"""An image's metadata"""
|
||||||
|
|
||||||
|
timestamp: int = Field(description="The creation timestamp of the image")
|
||||||
|
width: int = Field(description="The width of the image in pixels")
|
||||||
|
height: int = Field(description="The width of the image in pixels")
|
||||||
|
# TODO: figure out metadata
|
||||||
|
sd_metadata: Optional[dict] = Field(default={}, description="The image's SD-specific metadata")
|
@ -7,15 +7,19 @@ from abc import ABC, abstractmethod
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
from typing import Callable, Dict
|
from typing import Callable, Dict, List
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from invokeai.app.datatypes.image import ImageField, ImageType
|
import PIL.Image as PILImage
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from invokeai.app.datatypes.image import ImageField, ImageResponse, ImageType
|
||||||
|
from invokeai.app.datatypes.metadata import ImageMetadata
|
||||||
from invokeai.app.services.item_storage import PaginatedResults
|
from invokeai.app.services.item_storage import PaginatedResults
|
||||||
from invokeai.app.util.save_thumbnail import save_thumbnail
|
from invokeai.app.util.save_thumbnail import save_thumbnail
|
||||||
|
|
||||||
from invokeai.backend.image_util import PngWriter
|
from invokeai.backend.image_util import PngWriter
|
||||||
|
|
||||||
|
|
||||||
class ImageStorageBase(ABC):
|
class ImageStorageBase(ABC):
|
||||||
"""Responsible for storing and retrieving images."""
|
"""Responsible for storing and retrieving images."""
|
||||||
|
|
||||||
@ -26,12 +30,14 @@ class ImageStorageBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def list(
|
def list(
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||||
) -> PaginatedResults[ImageField]:
|
) -> PaginatedResults[ImageResponse]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(
|
||||||
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -75,32 +81,46 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
|
|
||||||
def list(
|
def list(
|
||||||
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
self, image_type: ImageType, page: int = 0, per_page: int = 10
|
||||||
) -> PaginatedResults[ImageField]:
|
) -> PaginatedResults[ImageResponse]:
|
||||||
dir_path = os.path.join(self.__output_folder, image_type)
|
dir_path = os.path.join(self.__output_folder, image_type)
|
||||||
image_paths = glob(f"{dir_path}/*.png")
|
image_paths = glob(f"{dir_path}/*.png")
|
||||||
|
count = len(image_paths)
|
||||||
|
|
||||||
# just want the filenames
|
# TODO: do all platforms support `getmtime`? seem to recall some do not...
|
||||||
image_filenames = list(map(lambda i: os.path.basename(i), image_paths))
|
sorted_image_paths = sorted(
|
||||||
|
glob(f"{dir_path}/*.png"), key=os.path.getmtime, reverse=True
|
||||||
# we want to sort the images by timestamp, but we don't trust the filesystem
|
|
||||||
# we do have a timestamp in the filename: `{uuid}_{timestamp}.png`
|
|
||||||
sorted_paths = sorted(
|
|
||||||
# extract the timestamp as int and multiply -1 to reverse sorting
|
|
||||||
image_filenames, key=lambda i: int(os.path.splitext(i)[0].split("_")[-1]) * -1
|
|
||||||
)
|
|
||||||
|
|
||||||
all_images = list(
|
|
||||||
map(lambda i: ImageField(image_type=image_type, image_name=i), sorted_paths)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
count = len(all_images)
|
page_of_image_paths = sorted_image_paths[
|
||||||
page_of_images = all_images[page * per_page : (page + 1) * per_page]
|
page * per_page : (page + 1) * per_page
|
||||||
|
]
|
||||||
|
|
||||||
page_count_trunc = int(count / per_page)
|
page_of_images: List[ImageResponse] = []
|
||||||
|
|
||||||
|
for path in page_of_image_paths:
|
||||||
|
filename = os.path.basename(path)
|
||||||
|
img = PILImage.open(path)
|
||||||
|
page_of_images.append(
|
||||||
|
ImageResponse(
|
||||||
|
image_type=image_type.value,
|
||||||
|
image_name=os.path.basename(path),
|
||||||
|
# TODO: DiskImageStorage should not be building URLs...?
|
||||||
|
image_url=f"api/v1/images/{image_type.value}/{filename}",
|
||||||
|
thumbnail_url=f"api/v1/images/{image_type.value}/thumbnails/{os.path.splitext(filename)[0]}.webp",
|
||||||
|
# TODO: Creation of this object should happen elsewhere, just making it fit here so it works
|
||||||
|
metadata=ImageMetadata(
|
||||||
|
timestamp=int(os.path.splitext(filename)[0].split("_")[-1]),
|
||||||
|
width=img.width,
|
||||||
|
height=img.height,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
page_count_trunc = int(count / per_page)
|
||||||
page_count_mod = count % per_page
|
page_count_mod = count % per_page
|
||||||
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
page_count = page_count_trunc if page_count_mod == 0 else page_count_trunc + 1
|
||||||
|
|
||||||
return PaginatedResults[ImageField](
|
return PaginatedResults[ImageResponse](
|
||||||
items=page_of_images,
|
items=page_of_images,
|
||||||
page=page,
|
page=page,
|
||||||
pages=page_count,
|
pages=page_count,
|
||||||
@ -119,8 +139,15 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
return image
|
return image
|
||||||
|
|
||||||
# TODO: make this a bit more flexible for e.g. cloud storage
|
# TODO: make this a bit more flexible for e.g. cloud storage
|
||||||
def get_path(self, image_type: ImageType, image_name: str) -> str:
|
def get_path(
|
||||||
path = os.path.join(self.__output_folder, image_type, image_name)
|
self, image_type: ImageType, image_name: str, is_thumbnail: bool = False
|
||||||
|
) -> str:
|
||||||
|
if is_thumbnail:
|
||||||
|
path = os.path.join(
|
||||||
|
self.__output_folder, image_type, "thumbnails", image_name
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = os.path.join(self.__output_folder, image_type, image_name)
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
def save(self, image_type: ImageType, image_name: str, image: Image) -> None:
|
||||||
@ -138,12 +165,19 @@ class DiskImageStorage(ImageStorageBase):
|
|||||||
|
|
||||||
def delete(self, image_type: ImageType, image_name: str) -> None:
|
def delete(self, image_type: ImageType, image_name: str) -> None:
|
||||||
image_path = self.get_path(image_type, image_name)
|
image_path = self.get_path(image_type, image_name)
|
||||||
|
thumbnail_path = self.get_path(image_type, image_name, True)
|
||||||
if os.path.exists(image_path):
|
if os.path.exists(image_path):
|
||||||
os.remove(image_path)
|
os.remove(image_path)
|
||||||
|
|
||||||
if image_path in self.__cache:
|
if image_path in self.__cache:
|
||||||
del self.__cache[image_path]
|
del self.__cache[image_path]
|
||||||
|
|
||||||
|
if os.path.exists(thumbnail_path):
|
||||||
|
os.remove(thumbnail_path)
|
||||||
|
|
||||||
|
if thumbnail_path in self.__cache:
|
||||||
|
del self.__cache[thumbnail_path]
|
||||||
|
|
||||||
def __get_cache(self, image_name: str) -> Image:
|
def __get_cache(self, image_name: str) -> Image:
|
||||||
return None if image_name not in self.__cache else self.__cache[image_name]
|
return None if image_name not in self.__cache else self.__cache[image_name]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user