InvokeAI/invokeai/app/api/routers/images.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

84 lines
2.9 KiB
Python
Raw Normal View History

# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
2023-04-03 18:32:43 +00:00
import io
from datetime import datetime, timezone
2023-03-03 06:02:00 +00:00
2023-04-04 01:05:15 +00:00
from fastapi import Path, Query, Request, UploadFile
from fastapi.responses import FileResponse, Response
2023-03-03 06:02:00 +00:00
from fastapi.routing import APIRouter
from PIL import Image
2023-04-04 01:05:15 +00:00
from invokeai.app.invocations.image import ImageField
from invokeai.app.services.item_storage import PaginatedResults
2023-03-03 06:02:00 +00:00
from ...services.image_storage import ImageType
from ..dependencies import ApiDependencies
2023-03-03 06:02:00 +00:00
# Router holding the image endpoints; every route below is mounted under /v1/images.
images_router = APIRouter(tags=["images"], prefix="/v1/images")
2023-03-03 06:02:00 +00:00
def _image_file_response(
    image_type: ImageType, image_name: str, thumbnail: bool = False
) -> FileResponse:
    """Build a FileResponse for a stored image or its thumbnail.

    Args:
        image_type: Image category, already validated against ``ImageType`` by FastAPI.
        image_name: Image file name taken from the request URL (untrusted).
        thumbnail: When True, serve ``thumbnails/<image_name>`` instead of the image.

    Raises:
        HTTPException: 400 if ``image_name`` is not a single plain path
            component; 404 if the resolved file does not exist.
    """
    import os

    from fastapi import HTTPException

    # image_name comes straight from the URL. Starlette path params cannot
    # contain "/", but they can be "." / ".." or (on Windows, after
    # percent-decoding) contain "\\", so only accept a single path component.
    if image_name in ("", ".", "..") or os.path.basename(image_name) != image_name:
        raise HTTPException(status_code=400, detail="Invalid image name")

    relative_name = "thumbnails/" + image_name if thumbnail else image_name
    filename = ApiDependencies.invoker.services.images.get_path(image_type, relative_name)

    # FileResponse only stats the file while sending, which turns a missing
    # image into a 500; report it as a 404 before building the response.
    if not os.path.isfile(filename):
        raise HTTPException(status_code=404, detail="Image not found")
    return FileResponse(filename)


@images_router.get("/{image_type}/{image_name}", operation_id="get_image")
async def get_image(
    image_type: ImageType = Path(description="The type of image to get"),
    image_name: str = Path(description="The name of the image to get"),
):
    """Gets an image file by type and name"""
    return _image_file_response(image_type, image_name)


@images_router.get("/{image_type}/thumbnails/{image_name}", operation_id="get_thumbnail")
async def get_thumbnail(
    image_type: ImageType = Path(description="The type of image to get"),
    image_name: str = Path(description="The name of the image to get"),
):
    """Gets the thumbnail of an image by type and name"""
    return _image_file_response(image_type, image_name, thumbnail=True)
2023-03-03 06:02:00 +00:00
@images_router.post(
    "/uploads/",
    operation_id="upload_image",
    responses={
        201: {"description": "The image was uploaded successfully"},
        415: {"description": "The uploaded file is not a supported image"},
    },
)
async def upload_image(file: UploadFile, request: Request):
    """Upload an image and store it as a PNG under the "uploads" image type.

    Responds with 201 and a ``Location`` header pointing at the stored image's
    ``get_image`` URL. Responds with 415 if the upload is not declared as an
    image or cannot be decoded by Pillow.
    """
    # content_type may be absent (None); treat that like a non-image upload
    # instead of crashing with AttributeError.
    if not (file.content_type or "").startswith("image"):
        return Response(status_code=415)

    contents = await file.read()
    try:
        im = Image.open(io.BytesIO(contents))
        # Image.open is lazy (reads only the header); decode fully here so
        # truncated/corrupt data yields 415 instead of failing later in save().
        im.load()
    except Exception:
        # Error opening/decoding the image. Narrower than a bare `except:` so
        # KeyboardInterrupt/SystemExit and other BaseExceptions still propagate.
        return Response(status_code=415)

    # NOTE(review): second-resolution names mean two uploads within the same
    # second share a filename; whether save() then overwrites depends on the
    # image service — consider adding a unique suffix.
    filename = f"{int(datetime.now(timezone.utc).timestamp())}.png"
    ApiDependencies.invoker.services.images.save("uploads", filename, im)
    return Response(
        status_code=201,
        headers={
            # Newer Starlette returns a URL object from url_for; header values
            # must be str, so convert explicitly (a no-op on older versions).
            "Location": str(
                request.url_for("get_image", image_type="uploads", image_name=filename)
            )
        },
    )
2023-04-04 01:05:15 +00:00
@images_router.get(
    "/uploads/",
    operation_id="list_uploads",
    responses={200: {"model": PaginatedResults[ImageField]}},
)
async def list_uploads(
    page: int = Query(default=0, description="The page of uploads to get"),
    per_page: int = Query(default=10, description="The number of uploads per page"),
) -> PaginatedResults[ImageField]:
    """Gets a list of uploads

    Returns one page of images of type ``ImageType.UPLOAD``, exactly as
    produced by the image service's ``list`` method.
    """
    image_service = ApiDependencies.invoker.services.images
    return image_service.list(ImageType.UPLOAD, page, per_page)