diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 0615ff187e..920181ff8b 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,5 +1,6 @@ import io -from fastapi import HTTPException, Path, Query, Request, Response, UploadFile +from typing import Optional +from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.routing import APIRouter from fastapi.responses import FileResponse from PIL import Image @@ -7,7 +8,11 @@ from invokeai.app.models.image import ( ImageCategory, ImageType, ) -from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO +from invokeai.app.services.models.image_record import ( + ImageDTO, + ImageRecordChanges, + ImageUrlsDTO, +) from invokeai.app.services.item_storage import PaginatedResults from ..dependencies import ApiDependencies @@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"]) ) async def upload_image( file: UploadFile, - image_type: ImageType, request: Request, response: Response, - image_category: ImageCategory = ImageCategory.GENERAL, + image_category: ImageCategory = Query( + default=ImageCategory.GENERAL, description="The category of the image" + ), + is_intermediate: bool = Query( + default=False, description="Whether this is an intermediate image" + ), + session_id: Optional[str] = Query( + default=None, description="The session ID associated with this upload, if any" + ), ) -> ImageDTO: """Uploads an image""" if not file.content_type.startswith("image"): @@ -46,9 +58,11 @@ async def upload_image( try: image_dto = ApiDependencies.invoker.services.images.create( - pil_image, - image_type, - image_category, + image=pil_image, + image_type=ImageType.UPLOAD, + image_category=image_category, + session_id=session_id, + is_intermediate=is_intermediate, ) response.status_code = 201 @@ -61,7 +75,7 @@ async def upload_image( @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Query(description="The type of image to delete"), + image_type: ImageType = Path(description="The type of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" @@ -73,6 +87,28 @@ async def delete_image( pass +@images_router.patch( + "/{image_type}/{image_name}", + operation_id="update_image", + response_model=ImageDTO, +) +async def update_image( + image_type: ImageType = Path(description="The type of image to update"), + image_name: str = Path(description="The name of the image to update"), + image_changes: ImageRecordChanges = Body( + description="The changes to apply to the image" + ), +) -> ImageDTO: + """Updates an image""" + + try: + return ApiDependencies.invoker.services.images.update( + image_type, image_name, image_changes + ) + except Exception as e: + raise HTTPException(status_code=400, detail="Failed to update image") + + @images_router.get( "/{image_type}/{image_name}/metadata", operation_id="get_image_metadata", @@ -85,9 +121,7 @@ async def get_image_metadata( """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto( - image_type, image_name - ) + return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) except Exception as e: raise HTTPException(status_code=404) @@ -113,9 +147,7 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name - ) + path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index da61641105..1ba498c9d8 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -39,6 +39,12 @@ class BaseInvocationOutput(BaseModel): return tuple(subclasses) +class InvocationMeta(BaseModel): + is_intermediate: bool = Field( + default=False, + description="Whether this is an intermediate node. Intermediate nodes are periodically deleted." + ) + class BaseInvocation(ABC, BaseModel): """A node to process inputs and produce outputs. May use dependency injection in __init__ to receive providers. @@ -78,6 +84,8 @@ class BaseInvocation(ABC, BaseModel): #fmt: off id: str = Field(description="The id of this node. Must be unique among all nodes.") + type: str = Field(description="The type of this node. Must be unique among all nodes.") + meta: InvocationMeta = Field(default=InvocationMeta(), description="The meta properties of this node.") #fmt: on diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 26e06a2af8..d900ecfdbf 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): image_dto = context.services.images.create( image=image_inpainted, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 12cebdf41d..9f78d72b77 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -370,6 +370,7 @@ class LatentsToImageInvocation(BaseInvocation): image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, + is_intermediate=self.meta.is_intermediate ) return ImageOutput( diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 024134cd46..a234693128 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 544951ea34..46b50145aa 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum): RESULT = "results" UPLOAD = "uploads" - INTERMEDIATE = "intermediates" class InvalidImageTypeException(ValueError): diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 4e1f73978b..188a411a6b 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -12,6 +12,7 @@ from invokeai.app.models.image import ( ) from invokeai.app.services.models.image_record import ( ImageRecord, + ImageRecordChanges, deserialize_image_record, ) from invokeai.app.services.item_storage import PaginatedResults @@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC): """Gets an image record.""" pass + @abstractmethod + def update( + self, + image_name: str, + image_type: ImageType, + changes: ImageRecordChanges, + ) -> None: + """Updates an image record.""" + pass + @abstractmethod def get_many( self, @@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC): session_id: Optional[str], node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: """Saves an image record.""" pass @@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, + is_intermediate BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, -- Updated via trigger updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, @@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return deserialize_image_record(dict(result)) + def update( + self, + image_name: str, + image_type: ImageType, + changes: ImageRecordChanges, + ) -> None: + try: + self._lock.acquire() + # Change the category of the image + if changes.image_category is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET image_category = ? + WHERE image_name = ?; + """, + (changes.image_category, image_name), + ) + + # Change the session associated with the image + if changes.session_id is not None: + self._cursor.execute( + f"""--sql + UPDATE images + SET session_id = ? + WHERE image_name = ?; + """, + (changes.session_id, image_name), + ) + self._conn.commit() + except sqlite3.Error as e: + self._conn.rollback() + raise ImageRecordSaveException from e + finally: + self._lock.release() + def get_many( self, image_type: ImageType, @@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): height: int, node_id: Optional[str], metadata: Optional[ImageMetadata], + is_intermediate: bool = False, ) -> datetime: try: metadata_json = ( @@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): height, node_id, session_id, - metadata + metadata, + is_intermediate ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, @@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id, session_id, metadata_json, + is_intermediate, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 914dd3b6d3..d0f7236fe2 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import ( from invokeai.app.services.models.image_record import ( ImageRecord, ImageDTO, + ImageRecordChanges, image_record_to_dto, ) from invokeai.app.services.image_file_storage import ( @@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import ( from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.urls import UrlServiceBase -from invokeai.app.util.misc import get_iso_timestamp if TYPE_CHECKING: from invokeai.app.services.graph import GraphExecutionState @@ -48,11 +48,21 @@ class ImageServiceABC(ABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, - metadata: Optional[ImageMetadata] = None, + intermediate: bool = False, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass + @abstractmethod + def update( + self, + image_type: ImageType, + image_name: str, + changes: ImageRecordChanges, + ) -> ImageDTO: + """Updates an image.""" + pass + @abstractmethod def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" @@ -157,6 +167,7 @@ class ImageService(ImageServiceABC): image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, + is_intermediate: bool = False, ) -> ImageDTO: if image_type not in ImageType: raise InvalidImageTypeException @@ -184,6 +195,8 @@ class ImageService(ImageServiceABC): image_category=image_category, width=width, height=height, + # Meta fields + is_intermediate=is_intermediate, # Nullable fields node_id=node_id, session_id=session_id, @@ -217,6 +230,7 @@ class ImageService(ImageServiceABC): created_at=created_at, updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, + is_intermediate=is_intermediate, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -231,6 +245,23 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem saving image record and file") raise e + def update( + self, + image_type: ImageType, + image_name: str, + changes: ImageRecordChanges, + ) -> ImageDTO: + try: + self._services.records.update(image_name, image_type, changes) + return self.get_dto(image_type, image_name) + except ImageRecordSaveException: + self._services.logger.error("Failed to update image record") + raise + except Exception as e: + self._services.logger.error("Problem updating image record") + raise e + + def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: try: return self._services.files.get(image_type, image_name) diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index c1155ff73e..26e4929be2 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,6 +1,6 @@ import datetime from typing import Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Extra, Field, StrictStr from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -31,6 +31,8 @@ class ImageRecord(BaseModel): description="The deleted timestamp of the image." ) """The deleted timestamp of the image.""" + is_intermediate: bool = Field(description="Whether this is an intermediate image.") + """Whether this is an intermediate image.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -48,6 +50,25 @@ class ImageRecord(BaseModel): """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" +class ImageRecordChanges(BaseModel, extra=Extra.forbid): + """A set of changes to apply to an image record. + + Only limited changes are valid: + - `image_category`: change the category of an image + - `session_id`: change the session associated with an image + """ + + image_category: Optional[ImageCategory] = Field( + description="The image's new category." + ) + """The image's new category.""" + session_id: Optional[StrictStr] = Field( + default=None, + description="The image's new session ID.", + ) + """The image's new session ID.""" + + class ImageUrlsDTO(BaseModel): """The URLs for an image and its thumbnail.""" @@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at = image_dict.get("created_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) + is_intermediate = image_dict.get("is_intermediate", False) raw_metadata = image_dict.get("metadata") @@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: created_at=created_at, updated_at=updated_at, deleted_at=deleted_at, + is_intermediate=is_intermediate, )