feat(app): simplified create image API

Graph, metadata and workflow all take stringified JSON only. This makes the API consistent and means we don't need to do a round-trip of pydantic parsing when handling this data.

It also prevents a failure mode where an uploaded image's metadata, workflow or graph are old and don't match the current schema.

As before, the frontend does strict validation and parsing when loading these values.
This commit is contained in:
psychedelicious 2024-05-17 19:25:04 +10:00
parent 93ebc175c6
commit 5928ade5fd
5 changed files with 40 additions and 47 deletions

View File

@ -6,13 +6,12 @@ from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request,
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from PIL import Image from PIL import Image
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO from invokeai.app.services.images.images_common import ImageDTO, ImageUrlsDTO
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutIDValidator
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -64,21 +63,19 @@ async def upload_image(
# TODO: retain non-invokeai metadata on upload? # TODO: retain non-invokeai metadata on upload?
# attempt to parse metadata from image # attempt to parse metadata from image
metadata_raw = pil_image.info.get("invokeai_metadata", None) metadata_raw = pil_image.info.get("invokeai_metadata", None)
if metadata_raw: if isinstance(metadata_raw, str):
try: metadata = metadata_raw
metadata = MetadataFieldValidator.validate_json(metadata_raw) else:
except ValidationError: ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image")
ApiDependencies.invoker.services.logger.warn("Failed to parse metadata for uploaded image") pass
pass
# attempt to parse workflow from image # attempt to parse workflow from image
workflow_raw = pil_image.info.get("invokeai_workflow", None) workflow_raw = pil_image.info.get("invokeai_workflow", None)
if workflow_raw is not None: if isinstance(workflow_raw, str):
try: workflow = workflow_raw
workflow = WorkflowWithoutIDValidator.validate_json(workflow_raw) else:
except ValidationError: ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image")
ApiDependencies.invoker.services.logger.warn("Failed to parse workflow for uploaded image") pass
pass
# attempt to extract graph from image # attempt to extract graph from image
graph_raw = pil_image.info.get("invokeai_graph", None) graph_raw = pil_image.info.get("invokeai_graph", None)

View File

@ -4,10 +4,6 @@ from typing import Optional
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageFileStorageBase(ABC): class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@ -34,9 +30,9 @@ class ImageFileStorageBase(ABC):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[Graph | str] = None, graph: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
"""Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp.""" """Saves an image and a 256x256 WEBP thumbnail. Returns a tuple of the image name, thumbnail name, and created timestamp."""

View File

@ -7,10 +7,7 @@ from PIL import Image, PngImagePlugin
from PIL.Image import Image as PILImageType from PIL.Image import Image as PILImageType
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
from .image_files_base import ImageFileStorageBase from .image_files_base import ImageFileStorageBase
@ -57,9 +54,9 @@ class DiskImageFileStorage(ImageFileStorageBase):
self, self,
image: PILImageType, image: PILImageType,
image_name: str, image_name: str,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[Graph | str] = None, graph: Optional[str] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
@ -70,17 +67,14 @@ class DiskImageFileStorage(ImageFileStorageBase):
info_dict = {} info_dict = {}
if metadata is not None: if metadata is not None:
metadata_json = metadata.model_dump_json() info_dict["invokeai_metadata"] = metadata
info_dict["invokeai_metadata"] = metadata_json pnginfo.add_text("invokeai_metadata", metadata)
pnginfo.add_text("invokeai_metadata", metadata_json)
if workflow is not None: if workflow is not None:
workflow_json = workflow.model_dump_json() info_dict["invokeai_workflow"] = workflow
info_dict["invokeai_workflow"] = workflow_json pnginfo.add_text("invokeai_workflow", workflow)
pnginfo.add_text("invokeai_workflow", workflow_json)
if graph is not None: if graph is not None:
graph_json = graph.model_dump_json() if isinstance(graph, Graph) else graph info_dict["invokeai_graph"] = graph
info_dict["invokeai_graph"] = graph_json pnginfo.add_text("invokeai_graph", graph)
pnginfo.add_text("invokeai_graph", graph_json)
# When saving the image, the image object's info field is not populated. We need to set it # When saving the image, the image object's info field is not populated. We need to set it
image.info = info_dict image.info = info_dict

View File

@ -11,9 +11,7 @@ from invokeai.app.services.image_records.image_records_common import (
ResourceOrigin, ResourceOrigin,
) )
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.shared.graph import Graph
from invokeai.app.services.shared.pagination import OffsetPaginatedResults from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.workflow_records.workflow_records_common import WorkflowWithoutID
class ImageServiceABC(ABC): class ImageServiceABC(ABC):
@ -52,9 +50,9 @@ class ImageServiceABC(ABC):
session_id: Optional[str] = None, session_id: Optional[str] = None,
board_id: Optional[str] = None, board_id: Optional[str] = None,
is_intermediate: Optional[bool] = False, is_intermediate: Optional[bool] = False,
metadata: Optional[MetadataField] = None, metadata: Optional[str] = None,
workflow: Optional[WorkflowWithoutID] = None, workflow: Optional[str] = None,
graph: Optional[Graph | str] = None, graph: Optional[str] = None,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass

View File

@ -180,9 +180,9 @@ class ImagesInterface(InvocationContextInterface):
# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None. # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None metadata_ = None
if metadata: if metadata:
metadata_ = metadata metadata_ = metadata.model_dump_json()
elif isinstance(self._data.invocation, WithMetadata): elif isinstance(self._data.invocation, WithMetadata) and self._data.invocation.metadata:
metadata_ = self._data.invocation.metadata metadata_ = self._data.invocation.metadata.model_dump_json()
# If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None. # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`, falling back to None.
board_id_ = None board_id_ = None
@ -191,6 +191,14 @@ class ImagesInterface(InvocationContextInterface):
elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board: elif isinstance(self._data.invocation, WithBoard) and self._data.invocation.board:
board_id_ = self._data.invocation.board.board_id board_id_ = self._data.invocation.board.board_id
workflow_ = None
if self._data.queue_item.workflow:
workflow_ = self._data.queue_item.workflow.model_dump_json()
graph_ = None
if self._data.queue_item.session.graph:
graph_ = self._data.queue_item.session.graph.model_dump_json()
return self._services.images.create( return self._services.images.create(
image=image, image=image,
is_intermediate=self._data.invocation.is_intermediate, is_intermediate=self._data.invocation.is_intermediate,
@ -198,8 +206,8 @@ class ImagesInterface(InvocationContextInterface):
board_id=board_id_, board_id=board_id_,
metadata=metadata_, metadata=metadata_,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
workflow=self._data.queue_item.workflow, workflow=workflow_,
graph=self._data.queue_item.session.graph, graph=graph_,
session_id=self._data.queue_item.session_id, session_id=self._data.queue_item.session_id,
node_id=self._data.invocation.id, node_id=self._data.invocation.id,
) )