feat(nodes): change intermediates handling

- `ImageType` is now restricted to `results` and `uploads`.
- Add a reserved `meta` field to nodes to hold the `is_intermediate` boolean. We can extend it in the future to support other node `meta`.
- Add a `is_intermediate` column to the `images` table to hold this. (When `latents`, `conditioning` etc are added to the DB, they will also have this column.)
- All nodes default to `*not* intermediate`. Nodes must explicitly be marked `intermediate` for their outputs to be `intermediate`.
- When building a graph, you can set `node.meta.is_intermediate=True` and it will be handled as an intermediate.
- Add a new `update()` method to the `ImageService`, and a route to call it. Updates have a strict model, currently only `session_id` and `image_category` may be updated.
- Add a new `update()` method to the `ImageRecordStorageService` to update the image record using the model.
This commit is contained in:
psychedelicious 2023-05-25 23:47:18 +10:00 committed by Kent Keirsey
parent 05fb0ac2b2
commit d2c8a53c55
9 changed files with 168 additions and 22 deletions

View File

@ -1,5 +1,6 @@
import io import io
from fastapi import HTTPException, Path, Query, Request, Response, UploadFile from typing import Optional
from fastapi import Body, HTTPException, Path, Query, Request, Response, UploadFile
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
@ -7,7 +8,11 @@ from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ImageType,
) )
from invokeai.app.services.models.image_record import ImageDTO, ImageUrlsDTO from invokeai.app.services.models.image_record import (
ImageDTO,
ImageRecordChanges,
ImageUrlsDTO,
)
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -27,10 +32,17 @@ images_router = APIRouter(prefix="/v1/images", tags=["images"])
) )
async def upload_image( async def upload_image(
file: UploadFile, file: UploadFile,
image_type: ImageType,
request: Request, request: Request,
response: Response, response: Response,
image_category: ImageCategory = ImageCategory.GENERAL, image_category: ImageCategory = Query(
default=ImageCategory.GENERAL, description="The category of the image"
),
is_intermediate: bool = Query(
default=False, description="Whether this is an intermediate image"
),
session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any"
),
) -> ImageDTO: ) -> ImageDTO:
"""Uploads an image""" """Uploads an image"""
if not file.content_type.startswith("image"): if not file.content_type.startswith("image"):
@ -46,9 +58,11 @@ async def upload_image(
try: try:
image_dto = ApiDependencies.invoker.services.images.create( image_dto = ApiDependencies.invoker.services.images.create(
pil_image, image=pil_image,
image_type, image_type=ImageType.UPLOAD,
image_category, image_category=image_category,
session_id=session_id,
is_intermediate=is_intermediate,
) )
response.status_code = 201 response.status_code = 201
@ -61,7 +75,7 @@ async def upload_image(
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_type}/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_type: ImageType = Query(description="The type of image to delete"), image_type: ImageType = Path(description="The type of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
@ -73,6 +87,28 @@ async def delete_image(
pass pass
@images_router.patch(
"/{image_type}/{image_name}",
operation_id="update_image",
response_model=ImageDTO,
)
async def update_image(
image_type: ImageType = Path(description="The type of image to update"),
image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image"
),
) -> ImageDTO:
"""Updates an image"""
try:
return ApiDependencies.invoker.services.images.update(
image_type, image_name, image_changes
)
except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/metadata", "/{image_type}/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
@ -85,9 +121,7 @@ async def get_image_metadata(
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto( return ApiDependencies.invoker.services.images.get_dto(image_type, image_name)
image_type, image_name
)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -113,9 +147,7 @@ async def get_image_full(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(image_type, image_name)
image_type, image_name
)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -39,6 +39,12 @@ class BaseInvocationOutput(BaseModel):
return tuple(subclasses) return tuple(subclasses)
class InvocationMeta(BaseModel):
is_intermediate: bool = Field(
default=False,
description="Whether this is an intermediate node. Intermediate nodes are periodically deleted."
)
class BaseInvocation(ABC, BaseModel): class BaseInvocation(ABC, BaseModel):
"""A node to process inputs and produce outputs. """A node to process inputs and produce outputs.
May use dependency injection in __init__ to receive providers. May use dependency injection in __init__ to receive providers.
@ -78,6 +84,8 @@ class BaseInvocation(ABC, BaseModel):
#fmt: off #fmt: off
id: str = Field(description="The id of this node. Must be unique among all nodes.") id: str = Field(description="The id of this node. Must be unique among all nodes.")
type: str = Field(description="The type of this node. Must be unique among all nodes.")
meta: InvocationMeta = Field(default=InvocationMeta(), description="The meta properties of this node.")
#fmt: on #fmt: on

View File

@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_inpainted, image=image_inpainted,
image_type=ImageType.INTERMEDIATE, image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,

View File

@ -370,6 +370,7 @@ class LatentsToImageInvocation(BaseInvocation):
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
is_intermediate=self.meta.is_intermediate
) )
return ImageOutput( return ImageOutput(

View File

@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.INTERMEDIATE, image_type=ImageType.RESULT,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,

View File

@ -10,7 +10,6 @@ class ImageType(str, Enum, metaclass=MetaEnum):
RESULT = "results" RESULT = "results"
UPLOAD = "uploads" UPLOAD = "uploads"
INTERMEDIATE = "intermediates"
class InvalidImageTypeException(ValueError): class InvalidImageTypeException(ValueError):

View File

@ -12,6 +12,7 @@ from invokeai.app.models.image import (
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageRecordChanges,
deserialize_image_record, deserialize_image_record,
) )
from invokeai.app.services.item_storage import PaginatedResults from invokeai.app.services.item_storage import PaginatedResults
@ -49,6 +50,16 @@ class ImageRecordStorageBase(ABC):
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
"""Updates an image record."""
pass
@abstractmethod @abstractmethod
def get_many( def get_many(
self, self,
@ -78,6 +89,7 @@ class ImageRecordStorageBase(ABC):
session_id: Optional[str], session_id: Optional[str],
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -125,6 +137,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT, session_id TEXT,
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Updated via trigger -- Updated via trigger
updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
@ -193,6 +206,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result)) return deserialize_image_record(dict(result))
def update(
self,
image_name: str,
image_type: ImageType,
changes: ImageRecordChanges,
) -> None:
try:
self._lock.acquire()
# Change the category of the image
if changes.image_category is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET image_category = ?
WHERE image_name = ?;
""",
(changes.image_category, image_name),
)
# Change the session associated with the image
if changes.session_id is not None:
self._cursor.execute(
f"""--sql
UPDATE images
SET session_id = ?
WHERE image_name = ?;
""",
(changes.session_id, image_name),
)
self._conn.commit()
except sqlite3.Error as e:
self._conn.rollback()
raise ImageRecordSaveException from e
finally:
self._lock.release()
def get_many( def get_many(
self, self,
image_type: ImageType, image_type: ImageType,
@ -265,6 +314,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height: int, height: int,
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
@ -281,9 +331,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
height, height,
node_id, node_id,
session_id, session_id,
metadata metadata,
is_intermediate
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
@ -294,6 +345,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id, node_id,
session_id, session_id,
metadata_json, metadata_json,
is_intermediate,
), ),
) )
self._conn.commit() self._conn.commit()

View File

@ -20,6 +20,7 @@ from invokeai.app.services.image_record_storage import (
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
ImageDTO, ImageDTO,
ImageRecordChanges,
image_record_to_dto, image_record_to_dto,
) )
from invokeai.app.services.image_file_storage import ( from invokeai.app.services.image_file_storage import (
@ -31,7 +32,6 @@ from invokeai.app.services.image_file_storage import (
from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults from invokeai.app.services.item_storage import ItemStorageABC, PaginatedResults
from invokeai.app.services.metadata import MetadataServiceBase from invokeai.app.services.metadata import MetadataServiceBase
from invokeai.app.services.urls import UrlServiceBase from invokeai.app.services.urls import UrlServiceBase
from invokeai.app.util.misc import get_iso_timestamp
if TYPE_CHECKING: if TYPE_CHECKING:
from invokeai.app.services.graph import GraphExecutionState from invokeai.app.services.graph import GraphExecutionState
@ -48,11 +48,21 @@ class ImageServiceABC(ABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
metadata: Optional[ImageMetadata] = None, intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@abstractmethod
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
"""Updates an image."""
pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
@ -157,6 +167,7 @@ class ImageService(ImageServiceABC):
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType: if image_type not in ImageType:
raise InvalidImageTypeException raise InvalidImageTypeException
@ -184,6 +195,8 @@ class ImageService(ImageServiceABC):
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
# Meta fields
is_intermediate=is_intermediate,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id, session_id=session_id,
@ -217,6 +230,7 @@ class ImageService(ImageServiceABC):
created_at=created_at, created_at=created_at,
updated_at=created_at, # this is always the same as the created_at at this time updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None, deleted_at=None,
is_intermediate=is_intermediate,
# Extra non-nullable fields for DTO # Extra non-nullable fields for DTO
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -231,6 +245,23 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem saving image record and file") self._services.logger.error("Problem saving image record and file")
raise e raise e
def update(
self,
image_type: ImageType,
image_name: str,
changes: ImageRecordChanges,
) -> ImageDTO:
try:
self._services.records.update(image_name, image_type, changes)
return self.get_dto(image_type, image_name)
except ImageRecordSaveException:
self._services.logger.error("Failed to update image record")
raise
except Exception as e:
self._services.logger.error("Problem updating image record")
raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_type, image_name)

View File

@ -1,6 +1,6 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ImageType
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
@ -31,6 +31,8 @@ class ImageRecord(BaseModel):
description="The deleted timestamp of the image." description="The deleted timestamp of the image."
) )
"""The deleted timestamp of the image.""" """The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image."""
session_id: Optional[str] = Field( session_id: Optional[str] = Field(
default=None, default=None,
description="The session ID that generated this image, if it is a generated image.", description="The session ID that generated this image, if it is a generated image.",
@ -48,6 +50,25 @@ class ImageRecord(BaseModel):
"""A limited subset of the image's generation metadata. Retrieve the image's session for full metadata.""" """A limited subset of the image's generation metadata. Retrieve the image's session for full metadata."""
class ImageRecordChanges(BaseModel, extra=Extra.forbid):
"""A set of changes to apply to an image record.
Only limited changes are valid:
- `image_category`: change the category of an image
- `session_id`: change the session associated with an image
"""
image_category: Optional[ImageCategory] = Field(
description="The image's new category."
)
"""The image's new category."""
session_id: Optional[StrictStr] = Field(
default=None,
description="The image's new session ID.",
)
"""The image's new session ID."""
class ImageUrlsDTO(BaseModel): class ImageUrlsDTO(BaseModel):
"""The URLs for an image and its thumbnail.""" """The URLs for an image and its thumbnail."""
@ -95,6 +116,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at = image_dict.get("created_at", get_iso_timestamp()) created_at = image_dict.get("created_at", get_iso_timestamp())
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False)
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -115,4 +137,5 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
created_at=created_at, created_at=created_at,
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate,
) )