feat(nodes): refactor image types

- Remove `ImageType` entirely, it is confusing
- Create `ResourceOrigin`, may be `internal` or `external`
- Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on
- Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change.

The new setup should account for our types of images, including the combinations we couldn't really handle until now:
- Canvas init and masks
- Canvas when saved-to-gallery or merged
This commit is contained in:
psychedelicious
2023-05-27 21:39:20 +10:00
committed by Kent Keirsey
parent fd47e70c92
commit 160267c71a
17 changed files with 291 additions and 311 deletions

View File

@ -1,39 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")

View File

@ -6,7 +6,7 @@ from fastapi.responses import FileResponse
from PIL import Image
from invokeai.app.models.image import (
ImageCategory,
ImageType,
ResourceOrigin,
)
from invokeai.app.services.models.image_record import (
ImageDTO,
@ -36,9 +36,6 @@ async def upload_image(
response: Response,
image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"),
show_in_gallery: bool = Query(
description="Whether this image should be shown in the gallery"
),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
@ -58,11 +55,10 @@ async def upload_image(
try:
image_dto = ApiDependencies.invoker.services.images.create(
image=pil_image,
image_type=ImageType.UPLOAD,
image_origin=ResourceOrigin.EXTERNAL,
image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
)
response.status_code = 201
@ -73,27 +69,27 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"),
image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
try:
ApiDependencies.invoker.services.images.delete(image_type, image_name)
ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e:
# TODO: Does this need any exception handling at all?
pass
@images_router.patch(
"/{image_type}/{image_name}",
"/{image_origin}/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_type: ImageType = Path(description="The type of image to update"),
image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
@ -103,31 +99,31 @@ async def update_image(
try:
return ApiDependencies.invoker.services.images.update(
image_type, image_name, image_changes
image_origin, image_name, image_changes
)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_type}/{image_name}/metadata",
"/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata",
response_model=ImageDTO,
)
async def get_image_metadata(
image_type: ImageType = Path(description="The type of image to get"),
image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"),
) -> ImageDTO:
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
except Exception as e:
raise HTTPException(status_code=404)
@images_router.get(
"/{image_type}/{image_name}",
"/{image_origin}/{image_name}",
operation_id="get_image_full",
response_class=Response,
responses={
@ -139,7 +135,7 @@ async def get_image_metadata(
},
)
async def get_image_full(
image_type: ImageType = Path(
image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get"
),
image_name: str = Path(description="The name of full-resolution image file to get"),
@ -147,7 +143,7 @@ async def get_image_full(
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -163,7 +159,7 @@ async def get_image_full(
@images_router.get(
"/{image_type}/{image_name}/thumbnail",
"/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail",
response_class=Response,
responses={
@ -175,14 +171,14 @@ async def get_image_full(
},
)
async def get_image_thumbnail(
image_type: ImageType = Path(description="The type of thumbnail image file to get"),
image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse:
"""Gets a thumbnail image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True
image_origin, image_name, thumbnail=True
)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)
@ -195,25 +191,25 @@ async def get_image_thumbnail(
@images_router.get(
"/{image_type}/{image_name}/urls",
"/{image_origin}/{image_name}/urls",
operation_id="get_image_urls",
response_model=ImageUrlsDTO,
)
async def get_image_urls(
image_type: ImageType = Path(description="The type of the image whose URL to get"),
image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL"""
try:
image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name
image_origin, image_name
)
thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True
image_origin, image_name, thumbnail=True
)
return ImageUrlsDTO(
image_type=image_type,
image_origin=image_origin,
image_name=image_name,
image_url=image_url,
thumbnail_url=thumbnail_url,
@ -228,30 +224,33 @@ async def get_image_urls(
response_model=PaginatedResults[ImageDTO],
)
async def list_images_with_metadata(
image_type: Optional[ImageType] = Query(
default=None, description="The type of images to list"
image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The origin of images to list"
),
image_category: Optional[ImageCategory] = Query(
default=None, description="The kind of images to list"
include_categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to include"
),
exclude_categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to exclude"
),
is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images"
),
show_in_gallery: Optional[bool] = Query(
default=None, description="Whether to list images that show in the gallery"
),
page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageDTO]:
"""Gets a list of images"""
if include_categories is not None and exclude_categories is not None:
raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.")
image_dtos = ApiDependencies.invoker.services.images.get_many(
page,
per_page,
image_type,
image_category,
image_origin,
include_categories,
exclude_categories,
is_intermediate,
show_in_gallery,
)
return image_dtos