feat(nodes): change intermediates handling

- `ImageType` is now restricted to `results` and `uploads`.
- Add a reserved `meta` field to nodes to hold the `is_intermediate` boolean. We can extend it in the future to support other node `meta`.
- Add a `is_intermediate` column to the `images` table to hold this. (When `latents`, `conditioning` etc are added to the DB, they will also have this column.)
- All nodes default to `*not* intermediate`. Nodes must explicitly be marked `intermediate` for their outputs to be `intermediate`.
- When building a graph, you can set `node.meta.is_intermediate=True` and it will be handled as an intermediate.
- Add a new `update()` method to the `ImageService`, and a route to call it. Updates have a strict model, currently only `session_id` and `image_category` may be updated.
- Add a new `update()` method to the `ImageRecordStorageService` to update the image record using the model.
This commit is contained in:
psychedelicious 2023-05-25 23:47:18 +10:00 committed by Kent Keirsey
parent 05fb0ac2b2
commit d2c8a53c55
9 changed files with 168 additions and 22 deletions

View File

@ -1,5 +1,6 @@
import io
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile
from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter
from fastapi.responses import FileResponse
from PIL import Image
@ -7,7 +8,11 @@ from invokeai.app.models.image import (
ImageCategory,
ImageType,
)
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO
from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies
@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
)
async def upload_image(
file: UploadFile,
image_type: ImageType,
request: Request,
response: Response,
image_category: ImageCategory = ImageCategory.GENERAL,
image_category: ImageCategory = Query(
default=ImageCategory.GENERAL, description="The category of the image"
),
is_intermediate: bool = Query(
default=False, description="Whether this is an intermediate image"
),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
) -> ImageDTO:
"""Uploads an image"""
if not file.content_type.startswith("image"):
@ -46,9 +58,11 @@ async def upload_image(
try:
image_dto = ApiDependencies.invoker.services.images.create(
pil_image,
image_type,
image_category,
image=pil_image,
image_type=ImageType.UPLOAD,
image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
)
response.status_code = 201
@ -61,7 +75,7 @@ async def upload_image(
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image(
image_type: ImageType = Query(description="The type of image to delete"),
image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"),
) -> None:
"""Deletes an image"""
@ -73,6 +87,28 @@ async def delete_image(
pass
@images_router.patch(
"/{image_type}/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_type: ImageType = Path(description="The type of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_type, image_name, image_changes
)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get(
"/{image_type}/{image_name}/metadata",
operation_id="get_image_metadata",
@ -85,9 +121,7 @@ async def get_image_metadata(
"""Gets an image's metadata"""
try:
return ApiDependencies.invoker.services.images.get_dto(
image_type, image_name
)
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
except Exception as e:
raise HTTPException(status_code=404)
@ -113,9 +147,7 @@ async def get_image_full(
"""Gets a full-resolution image file"""
try:
path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name
)
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404)

View File

@ -39,6 +39,12 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses)
class InvocationMeta(BaseModel):
is_intermediate: bool = Field(
default=False,
description="Whether this is an intermediate node. Intermediate nodes are periodically deleted."
)
class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers.
@ -78,6 +84,8 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.")
type: str = Field(description="The type of this node. Must be unique among all nodes.")
meta: InvocationMeta = Field(default=InvocationMeta(), description="The meta properties of this node.")
#fmt: on

View File

@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create(
image=image_inpainted,
image_type=ImageType.INTERMEDIATE,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,

View File

@ -370,6 +370,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id,
node_id=self.id,
is_intermediate=self.meta.is_intermediate
)
return ImageOutput(

View File

@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results?
image_dto = context.services.images.create(
image=results[0][0],
image_type=ImageType.INTERMEDIATE,
image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,

View File

@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum):
RESULT = "results"
UPLOAD = "uploads"
INTERMEDIATE = "intermediates"
class InvalidImageTypeException(ValueError):

View File

@ -12,6 +12,7 @@ from invokeai.app.models.image import (
)
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageRecordChanges,
deserialize_image_record,
)
from invokeai.app.services.item_storage import PaginatedResults
@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record."""
pass
@abstractmethod
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod
def get_many(
self,
@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str],
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
"""Saves an image record."""
pass
@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT,
node_id TEXT,
metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many(
self,
image_type: ImageType,
@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height: int,
node_id: Optional[str],
metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime:
try:
metadata_json = (
@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height,
node_id,
session_id,
metadata
metadata,
is_intermediate
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?);
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""",
(
image_name,
@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id,
session_id,
metadata_json,
is_intermediate,
),
)
self._conn.commit()

View File

@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import (
from invokeai.app.services.models.image_record import (
ImageRecord,
ImageDTO,
ImageRecordChanges,
image_record_to_dto,
)
from invokeai.app.services.image_file_storage import (
@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import (
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState
@ -48,11 +48,21 @@ class ImageServiceABC(ABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None,
intermediate: bool = False,
) -> ImageDTO:
"""Creates an image, storing the file and its metadata."""
pass
@abstractmethod
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image."""
@ -157,6 +167,7 @@ class ImageService(ImageServiceABC):
image_category: ImageCategory,
node_id: Optional[str] = None,
session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO:
if image_type not in ImageType:
raise InvalidImageTypeException
@ -184,6 +195,8 @@ class ImageService(ImageServiceABC):
image_category=image_category,
width=width,
height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields
node_id=node_id,
session_id=session_id,
@ -217,6 +230,7 @@ class ImageService(ImageServiceABC):
created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO
image_url=image_url,
thumbnail_url=thumbnail_url,
@ -231,6 +245,23 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem saving image record and file")
raise e
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_type, changes)
return self.get_dto(image_type, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try:
return self._services.files.get(image_type, image_name)

View File

@ -1,6 +1,6 @@
import datetime
from typing import Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp
@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
description="The deleted timestamp of the image."
)
"""The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field(
default=None,
description="The session ID that generated this image, if it is a generated image.",
@ -48,6 +50,25 @@ class ImageRecord(BaseModel):
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail."""
@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata")
@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at=created_at,
updated_at=updated_at,
deleted_at=deleted_at,
is_intermediate=is_intermediate,
)