feat(api): add metadata to upload route

Canvas images are saved by uploading a blob generated from the HTML canvas element. This means the existing metadata handling, inside the graph execution engine, is not available.

To save metadata to canvas images, we need to provide it when uploading that blob.

The upload route now has a `metadata` body param. If this is provided, we use it over any metadata embedded in the image.
This commit is contained in:
psychedelicious 2024-05-20 09:33:50 +10:00
parent ba8bed6870
commit ecfff6cb1e

View File

@ -6,7 +6,7 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, JsonValue
from invokeai.app.invocations.fields import MetadataField from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
@ -41,14 +41,17 @@ async def upload_image(
board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"), board_id: Optional[str] = Query(default=None, description="The board to add this image to, if any"),
session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"), session_id: Optional[str] = Query(default=None, description="The session ID associated with this upload, if any"),
crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"), crop_visible: Optional[bool] = Query(default=False, description="Whether to crop the image"),
metadata: Optional[JsonValue] = Body(
default=None, description="The metadata to associate with the image", embed=True
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type or not file.content_type.startswith("image"): if not file.content_type or not file.content_type.startswith("image"):
raise HTTPException(status_code=415, detail="Not an image") raise HTTPException(status_code=415, detail="Not an image")
metadata = None _metadata = None
workflow = None _workflow = None
graph = None _graph = None
contents = await file.read() contents = await file.read()
try: try:
@ -62,9 +65,9 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload? # TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image # attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None) metadata_raw = metadata if isinstance(metadata, str) else pil_image.info.get("invokeai_metadata", None)
if isinstance(metadata_raw, str): if isinstance(metadata_raw, str):
metadata = metadata_raw _metadata = metadata_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
pass pass
@ -72,7 +75,7 @@ async def upload_image(
# attempt to parse workflow from image # attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None) workflow_raw = pil_image.info.get("invokeai_workflow", None)
if isinstance(workflow_raw, str): if isinstance(workflow_raw, str):
workflow = workflow_raw _workflow = workflow_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
pass pass
@ -80,7 +83,7 @@ async def upload_image(
# attempt to extract graph from image # attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None) graph_raw = pil_image.info.get("invokeai_graph", None)
if isinstance(graph_raw, str): if isinstance(graph_raw, str):
graph = graph_raw _graph = graph_raw
else: else:
ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image") ApiDependencies.invoker.services.logger.warn("Failed to parse graph for uploaded image")
pass pass
@ -92,9 +95,9 @@ async def upload_image(
image_category=image_category, image_category=image_category,
session_id=session_id, session_id=session_id,
board_id=board_id, board_id=board_id,
metadata=metadata, metadata=_metadata,
workflow=workflow, workflow=_workflow,
graph=graph, graph=_graph,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
) )