feat(nodes): refactor image types

- Remove `ImageType` entirely, it is confusing
- Create `ResourceOrigin`, may be `internal` or `external`
- Revamp `ImageCategory`, may be `general`, `mask`, `control`, `user`, `other`. Expect to add more as time goes on
- Update images `list` route to accept `include_categories` OR `exclude_categories` query parameters to afford finer-grained querying. All services are updated to accomodate this change.

The new setup should account for our types of images, including the combinations we couldn't really handle until now:
- Canvas init and masks
- Canvas when saved-to-gallery or merged
This commit is contained in:
psychedelicious 2023-05-27 21:39:20 +10:00 committed by Kent Keirsey
parent fd47e70c92
commit 160267c71a
17 changed files with 291 additions and 311 deletions

View File

@ -1,39 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageType
class ImageResponseMetadata(BaseModel):
"""An image's metadata. Used only in HTTP responses."""
created: int = Field(description="The creation timestamp of the image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# invokeai: Optional[InvokeAIMetadata] = Field(
# description="The image's InvokeAI-specific metadata"
# )
class ImageResponse(BaseModel):
"""The response type for images"""
image_type: ImageType = Field(description="The type of the image")
image_name: str = Field(description="The name of the image")
image_url: str = Field(description="The url of the image")
thumbnail_url: str = Field(description="The url of the image's thumbnail")
metadata: ImageResponseMetadata = Field(description="The image's metadata")
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")
class SavedImage(BaseModel):
image_name: str = Field(description="The name of the saved image")
thumbnail_name: str = Field(description="The name of the saved thumbnail")
created: int = Field(description="The created timestamp of the saved image")

View File

@ -6,7 +6,7 @@ from fastapi.responses import FileResponse
from PIL import Image from PIL import Image
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageDTO, ImageDTO,
@ -36,9 +36,6 @@ async def upload_image(
response: Response, response: Response,
image_category: ImageCategory = Query(description="The category of the image"), image_category: ImageCategory = Query(description="The category of the image"),
is_intermediate: bool = Query(description="Whether this is an intermediate image"), is_intermediate: bool = Query(description="Whether this is an intermediate image"),
show_in_gallery: bool = Query(
description="Whether this image should be shown in the gallery"
),
session_id: Optional[str] = Query( session_id: Optional[str] = Query(
default=None, description="The session ID associated with this upload, if any" default=None, description="The session ID associated with this upload, if any"
), ),
@ -58,11 +55,10 @@ async def upload_image(
try: try:
image_dto = ApiDependencies.invoker.services.images.create( image_dto = ApiDependencies.invoker.services.images.create(
image=pil_image, image=pil_image,
image_type=ImageType.UPLOAD, image_origin=ResourceOrigin.EXTERNAL,
image_category=image_category, image_category=image_category,
session_id=session_id, session_id=session_id,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
) )
response.status_code = 201 response.status_code = 201
@ -73,27 +69,27 @@ async def upload_image(
raise HTTPException(status_code=500, detail="Failed to create image") raise HTTPException(status_code=500, detail="Failed to create image")
@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") @images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image")
async def delete_image( async def delete_image(
image_type: ImageType = Path(description="The type of image to delete"), image_origin: ResourceOrigin = Path(description="The origin of image to delete"),
image_name: str = Path(description="The name of the image to delete"), image_name: str = Path(description="The name of the image to delete"),
) -> None: ) -> None:
"""Deletes an image""" """Deletes an image"""
try: try:
ApiDependencies.invoker.services.images.delete(image_type, image_name) ApiDependencies.invoker.services.images.delete(image_origin, image_name)
except Exception as e: except Exception as e:
# TODO: Does this need any exception handling at all? # TODO: Does this need any exception handling at all?
pass pass
@images_router.patch( @images_router.patch(
"/{image_type}/{image_name}", "/{image_origin}/{image_name}",
operation_id="update_image", operation_id="update_image",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def update_image( async def update_image(
image_type: ImageType = Path(description="The type of image to update"), image_origin: ResourceOrigin = Path(description="The origin of image to update"),
image_name: str = Path(description="The name of the image to update"), image_name: str = Path(description="The name of the image to update"),
image_changes: ImageRecordChanges = Body( image_changes: ImageRecordChanges = Body(
description="The changes to apply to the image" description="The changes to apply to the image"
@ -103,31 +99,31 @@ async def update_image(
try: try:
return ApiDependencies.invoker.services.images.update( return ApiDependencies.invoker.services.images.update(
image_type, image_name, image_changes image_origin, image_name, image_changes
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail="Failed to update image") raise HTTPException(status_code=400, detail="Failed to update image")
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/metadata", "/{image_origin}/{image_name}/metadata",
operation_id="get_image_metadata", operation_id="get_image_metadata",
response_model=ImageDTO, response_model=ImageDTO,
) )
async def get_image_metadata( async def get_image_metadata(
image_type: ImageType = Path(description="The type of image to get"), image_origin: ResourceOrigin = Path(description="The origin of image to get"),
image_name: str = Path(description="The name of image to get"), image_name: str = Path(description="The name of image to get"),
) -> ImageDTO: ) -> ImageDTO:
"""Gets an image's metadata""" """Gets an image's metadata"""
try: try:
return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name)
except Exception as e: except Exception as e:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@images_router.get( @images_router.get(
"/{image_type}/{image_name}", "/{image_origin}/{image_name}",
operation_id="get_image_full", operation_id="get_image_full",
response_class=Response, response_class=Response,
responses={ responses={
@ -139,7 +135,7 @@ async def get_image_metadata(
}, },
) )
async def get_image_full( async def get_image_full(
image_type: ImageType = Path( image_origin: ResourceOrigin = Path(
description="The type of full-resolution image file to get" description="The type of full-resolution image file to get"
), ),
image_name: str = Path(description="The name of full-resolution image file to get"), image_name: str = Path(description="The name of full-resolution image file to get"),
@ -147,7 +143,7 @@ async def get_image_full(
"""Gets a full-resolution image file""" """Gets a full-resolution image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name)
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -163,7 +159,7 @@ async def get_image_full(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/thumbnail", "/{image_origin}/{image_name}/thumbnail",
operation_id="get_image_thumbnail", operation_id="get_image_thumbnail",
response_class=Response, response_class=Response,
responses={ responses={
@ -175,14 +171,14 @@ async def get_image_full(
}, },
) )
async def get_image_thumbnail( async def get_image_thumbnail(
image_type: ImageType = Path(description="The type of thumbnail image file to get"), image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"),
image_name: str = Path(description="The name of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"),
) -> FileResponse: ) -> FileResponse:
"""Gets a thumbnail image file""" """Gets a thumbnail image file"""
try: try:
path = ApiDependencies.invoker.services.images.get_path( path = ApiDependencies.invoker.services.images.get_path(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
if not ApiDependencies.invoker.services.images.validate_path(path): if not ApiDependencies.invoker.services.images.validate_path(path):
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
@ -195,25 +191,25 @@ async def get_image_thumbnail(
@images_router.get( @images_router.get(
"/{image_type}/{image_name}/urls", "/{image_origin}/{image_name}/urls",
operation_id="get_image_urls", operation_id="get_image_urls",
response_model=ImageUrlsDTO, response_model=ImageUrlsDTO,
) )
async def get_image_urls( async def get_image_urls(
image_type: ImageType = Path(description="The type of the image whose URL to get"), image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"),
image_name: str = Path(description="The name of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"),
) -> ImageUrlsDTO: ) -> ImageUrlsDTO:
"""Gets an image and thumbnail URL""" """Gets an image and thumbnail URL"""
try: try:
image_url = ApiDependencies.invoker.services.images.get_url( image_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name image_origin, image_name
) )
thumbnail_url = ApiDependencies.invoker.services.images.get_url( thumbnail_url = ApiDependencies.invoker.services.images.get_url(
image_type, image_name, thumbnail=True image_origin, image_name, thumbnail=True
) )
return ImageUrlsDTO( return ImageUrlsDTO(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -228,30 +224,33 @@ async def get_image_urls(
response_model=PaginatedResults[ImageDTO], response_model=PaginatedResults[ImageDTO],
) )
async def list_images_with_metadata( async def list_images_with_metadata(
image_type: Optional[ImageType] = Query( image_origin: Optional[ResourceOrigin] = Query(
default=None, description="The type of images to list" default=None, description="The origin of images to list"
), ),
image_category: Optional[ImageCategory] = Query( include_categories: Optional[list[ImageCategory]] = Query(
default=None, description="The kind of images to list" default=None, description="The categories of image to include"
),
exclude_categories: Optional[list[ImageCategory]] = Query(
default=None, description="The categories of image to exclude"
), ),
is_intermediate: Optional[bool] = Query( is_intermediate: Optional[bool] = Query(
default=None, description="Whether to list intermediate images" default=None, description="Whether to list intermediate images"
), ),
show_in_gallery: Optional[bool] = Query(
default=None, description="Whether to list images that show in the gallery"
),
page: int = Query(default=0, description="The page of images to get"), page: int = Query(default=0, description="The page of images to get"),
per_page: int = Query(default=10, description="The number of images per page"), per_page: int = Query(default=10, description="The number of images per page"),
) -> PaginatedResults[ImageDTO]: ) -> PaginatedResults[ImageDTO]:
"""Gets a list of images""" """Gets a list of images"""
if include_categories is not None and exclude_categories is not None:
raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.")
image_dtos = ApiDependencies.invoker.services.images.get_many( image_dtos = ApiDependencies.invoker.services.images.get_many(
page, page,
per_page, per_page,
image_type, image_origin,
image_category, include_categories,
exclude_categories,
is_intermediate, is_intermediate,
show_in_gallery,
) )
return image_dtos return image_dtos

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageOps from PIL import Image, ImageOps
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = context.services.images.get_pil_image( mask = context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
# Convert to cv image/mask # Convert to cv image/mask
@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_inpainted, image=image_inpainted,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -10,9 +10,9 @@ import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.app.models.image import ColorField, ImageField, ImageType from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin
from invokeai.app.invocations.util.choose_model import choose_model from invokeai.app.invocations.util.choose_model import choose_model
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.generator.inpaint import infill_methods from invokeai.backend.generator.inpaint import infill_methods
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generate_output.image, image=generate_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation):
None None
if self.image is None if self.image is None
else context.services.images.get_pil_image( else context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name)
) )
# Handle invalid model parameter # Handle invalid model parameter
@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=generator_output.image, image=generator_output.image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -7,7 +7,7 @@ import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from ..models.image import ImageCategory, ImageField, ImageType from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation):
) )
# fmt: on # fmt: on
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name)
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if image: if image:
image.show() image.show()
@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=self.image.image_name, image_name=self.image.image_name,
image_type=self.image.image_type, image_origin=self.image.image_origin,
), ),
width=image.width, width=image.width,
height=image.height, height=image.height,
@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_crop = Image.new( image_crop = Image.new(
@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_crop, image=image_crop,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
base_image = context.services.images.get_pil_image( base_image = context.services.images.get_pil_image(
self.base_image.image_type, self.base_image.image_name self.base_image.image_origin, self.base_image.image_name
) )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
mask = ( mask = (
None None
if self.mask is None if self.mask is None
else ImageOps.invert( else ImageOps.invert(
context.services.images.get_pil_image( context.services.images.get_pil_image(
self.mask.image_type, self.mask.image_name self.mask.image_origin, self.mask.image_name
) )
) )
) )
@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=new_image, image=new_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> MaskOutput: def invoke(self, context: InvocationContext) -> MaskOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_mask = image.split()[-1] image_mask = image.split()[-1]
@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image_mask, image=image_mask,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.MASK, image_category=ImageCategory.MASK,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig):
return MaskOutput( return MaskOutput(
mask=ImageField( mask=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image1 = context.services.images.get_pil_image( image1 = context.services.images.get_pil_image(
self.image1.image_type, self.image1.image_name self.image1.image_origin, self.image1.image_name
) )
image2 = context.services.images.get_pil_image( image2 = context.services.images.get_pil_image(
self.image2.image_type, self.image2.image_name self.image2.image_origin, self.image2.image_name
) )
multiply_image = ImageChops.multiply(image1, image2) multiply_image = ImageChops.multiply(image1, image2)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=multiply_image, image=multiply_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
channel_image = image.getchannel(self.channel) channel_image = image.getchannel(self.channel)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=channel_image, image=channel_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
converted_image = image.convert(self.mode) converted_image = image.convert(self.mode)
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=converted_image, image=converted_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_type=image_dto.image_type, image_name=image_dto.image_name image_origin=image_dto.image_origin, image_name=image_dto.image_name
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
blur = ( blur = (
@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=blur_image, image=blur_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -409,7 +409,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -430,7 +430,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 image_arr = numpy.asarray(image, dtype=numpy.float32) / 255
@ -440,7 +440,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=lerp_image, image=lerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -450,7 +450,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -471,7 +471,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
image_arr = numpy.asarray(image, dtype=numpy.float32) image_arr = numpy.asarray(image, dtype=numpy.float32)
@ -486,7 +486,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=ilerp_image, image=ilerp_image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -496,7 +496,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput
from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.app.util.misc import SEED_MAX, get_random_seed
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from ..models.image import ColorField, ImageCategory, ImageField, ImageType from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import ( from .baseinvocation import (
BaseInvocation, BaseInvocation,
InvocationContext, InvocationContext,
@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
solid_bg = Image.new("RGBA", image.size, self.color.tuple()) solid_bg = Image.new("RGBA", image.size, self.color.tuple())
@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
infilled = tile_fill_missing( infilled = tile_fill_missing(
@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=infilled, image=infilled,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np import numpy as np
from ..services.image_file_storage import ImageType from ..services.image_file_storage import ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput from .image import ImageField, ImageOutput
from .compel import ConditioningField from .compel import ConditioningField
@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation):
# and gnenerate unique image_name # and gnenerate unique image_name
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image, image=image,
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
node_id=self.id, node_id=self.id,
@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,
@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation):
# self.image.image_type, self.image.image_name # self.image.image_type, self.image.image_name
# ) # )
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
# TODO: this only really needs the vae # TODO: this only really needs the vae

View File

@ -2,7 +2,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -4,7 +4,7 @@ from typing import Literal, Union
from pydantic import Field from pydantic import Field
from invokeai.app.models.image import ImageCategory, ImageField, ImageType from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .image import ImageOutput from .image import ImageOutput
@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image( image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name self.image.image_origin, self.image.image_name
) )
results = context.services.restoration.upscale_and_reconstruct( results = context.services.restoration.upscale_and_reconstruct(
image_list=[[image, 0]], image_list=[[image, 0]],
@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation):
# TODO: can this return multiple results? # TODO: can this return multiple results?
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=results[0][0],
image_type=ImageType.RESULT, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,
session_id=context.graph_execution_state_id, session_id=context.graph_execution_state_id,
@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation):
return ImageOutput( return ImageOutput(
image=ImageField( image=ImageField(
image_name=image_dto.image_name, image_name=image_dto.image_name,
image_type=image_dto.image_type, image_origin=image_dto.image_origin,
), ),
width=image_dto.width, width=image_dto.width,
height=image_dto.height, height=image_dto.height,

View File

@ -5,30 +5,52 @@ from pydantic import BaseModel, Field
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
class ImageType(str, Enum, metaclass=MetaEnum): class ResourceOrigin(str, Enum, metaclass=MetaEnum):
"""The type of an image.""" """The origin of a resource (eg image).
RESULT = "results" - INTERNAL: The resource was created by the application.
UPLOAD = "uploads" - EXTERNAL: The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
INTERNAL = "internal"
"""The resource was created by the application."""
EXTERNAL = "external"
"""The resource was not created by the application.
This may be a user-initiated upload, or an internal application upload (eg Canvas init image).
"""
class InvalidImageTypeException(ValueError): class InvalidOriginException(ValueError):
"""Raised when a provided value is not a valid ImageType. """Raised when a provided value is not a valid ResourceOrigin.
Subclasses `ValueError`. Subclasses `ValueError`.
""" """
def __init__(self, message="Invalid image type."): def __init__(self, message="Invalid resource origin."):
super().__init__(message) super().__init__(message)
class ImageCategory(str, Enum, metaclass=MetaEnum): class ImageCategory(str, Enum, metaclass=MetaEnum):
"""The category of an image. Use ImageCategory.OTHER for non-default categories.""" """The category of an image.
- GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.
- MASK: The image is a mask image.
- CONTROL: The image is a ControlNet control image.
- USER: The image is a user-provide image.
- OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.
"""
GENERAL = "general" GENERAL = "general"
CONTROL = "control" """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose."""
MASK = "mask" MASK = "mask"
"""MASK: The image is a mask image."""
CONTROL = "control"
"""CONTROL: The image is a ControlNet control image."""
USER = "user"
"""USER: The image is a user-provide image."""
OTHER = "other" OTHER = "other"
"""OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes."""
class InvalidImageCategoryException(ValueError): class InvalidImageCategoryException(ValueError):
@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError):
class ImageField(BaseModel): class ImageField(BaseModel):
"""An image field used for passing image objects between invocations""" """An image field used for passing image objects between invocations"""
image_type: ImageType = Field( image_origin: ResourceOrigin = Field(
default=ImageType.RESULT, description="The type of the image" default=ResourceOrigin.INTERNAL, description="The type of the image"
) )
image_name: Optional[str] = Field(default=None, description="The name of the image") image_name: Optional[str] = Field(default=None, description="The name of the image")
class Config: class Config:
schema_extra = {"required": ["image_type", "image_name"]} schema_extra = {"required": ["image_origin", "image_name"]}
class ColorField(BaseModel): class ColorField(BaseModel):
@ -61,3 +83,11 @@ class ColorField(BaseModel):
def tuple(self) -> Tuple[int, int, int, int]: def tuple(self) -> Tuple[int, int, int, int]:
return (self.r, self.g, self.b, self.a) return (self.r, self.g, self.b, self.a)
class ProgressImage(BaseModel):
"""The progress image sent intermittently during processing"""
width: int = Field(description="The effective width of the image in pixels")
height: int = Field(description="The effective height of the image in pixels")
dataURL: str = Field(description="The image data as a b64 data URL")

View File

@ -1,7 +1,7 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from typing import Any, Optional from typing import Any
from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.image import ProgressImage
from invokeai.app.util.misc import get_timestamp from invokeai.app.util.misc import get_timestamp

View File

@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from send2trash import send2trash from send2trash import send2trash
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail
@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC):
"""Low-level service responsible for storing and retrieving image files.""" """Low-level service responsible for storing and retrieving image files."""
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Retrieves an image as PIL Image.""" """Retrieves an image as PIL Image."""
pass pass
@abstractmethod @abstractmethod
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the internal path to an image or thumbnail.""" """Gets the internal path to an image or thumbnail."""
pass pass
@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC):
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image and its thumbnail (if one exists).""" """Deletes an image and its thumbnail (if one exists)."""
pass pass
@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase):
Path(output_folder).mkdir(parents=True, exist_ok=True) Path(output_folder).mkdir(parents=True, exist_ok=True)
# TODO: don't hard-code. get/save/delete should maybe take subpath? # TODO: don't hard-code. get/save/delete should maybe take subpath?
for image_type in ImageType: for image_origin in ResourceOrigin:
Path(os.path.join(output_folder, image_type)).mkdir( Path(os.path.join(output_folder, image_origin)).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir(
parents=True, exist_ok=True parents=True, exist_ok=True
) )
def get(self, image_type: ImageType, image_name: str) -> PILImageType: def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
cache_item = self.__get_cache(image_path) cache_item = self.__get_cache(image_path)
if cache_item: if cache_item:
return cache_item return cache_item
@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase):
def save( def save(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
metadata: Optional[ImageMetadata] = None, metadata: Optional[ImageMetadata] = None,
thumbnail_size: int = 256, thumbnail_size: int = 256,
) -> None: ) -> None:
try: try:
image_path = self.get_path(image_type, image_name) image_path = self.get_path(image_origin, image_name)
if metadata is not None: if metadata is not None:
pnginfo = PngImagePlugin.PngInfo() pnginfo = PngImagePlugin.PngInfo()
@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
image.save(image_path, "PNG") image.save(image_path, "PNG")
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True)
thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image = make_thumbnail(image, thumbnail_size)
thumbnail_image.save(thumbnail_path) thumbnail_image.save(thumbnail_path)
@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
except Exception as e: except Exception as e:
raise ImageFileSaveException from e raise ImageFileSaveException from e
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
image_path = self.get_path(image_type, basename) image_path = self.get_path(image_origin, basename)
if os.path.exists(image_path): if os.path.exists(image_path):
send2trash(image_path) send2trash(image_path)
@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
del self.__cache[image_path] del self.__cache[image_path]
thumbnail_name = get_thumbnail_name(image_name) thumbnail_name = get_thumbnail_name(image_name)
thumbnail_path = self.get_path(image_type, thumbnail_name, True) thumbnail_path = self.get_path(image_origin, thumbnail_name, True)
if os.path.exists(thumbnail_path): if os.path.exists(thumbnail_path):
send2trash(thumbnail_path) send2trash(thumbnail_path)
@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase):
# TODO: make this a bit more flexible for e.g. cloud storage # TODO: make this a bit more flexible for e.g. cloud storage
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
# strip out any relative path shenanigans # strip out any relative path shenanigans
basename = os.path.basename(image_name) basename = os.path.basename(image_name)
@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase):
if thumbnail: if thumbnail:
thumbnail_name = get_thumbnail_name(basename) thumbnail_name = get_thumbnail_name(basename)
path = os.path.join( path = os.path.join(
self.__output_folder, image_type, "thumbnails", thumbnail_name self.__output_folder, image_origin, "thumbnails", thumbnail_name
) )
else: else:
path = os.path.join(self.__output_folder, image_type, basename) path = os.path.join(self.__output_folder, image_origin, basename)
abspath = os.path.abspath(path) abspath = os.path.abspath(path)

View File

@ -8,7 +8,7 @@ from typing import Optional, Union
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
) )
from invokeai.app.services.models.image_record import ( from invokeai.app.services.models.image_record import (
ImageRecord, ImageRecord,
@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC):
# TODO: Implement an `update()` method # TODO: Implement an `update()` method
@abstractmethod @abstractmethod
def get(self, image_type: ImageType, image_name: str) -> ImageRecord: def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC):
def update( def update(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
"""Updates an image record.""" """Updates an image record."""
@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC):
self, self,
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
image_type: Optional[ImageType] = None, image_origin: Optional[ResourceOrigin] = None,
image_category: Optional[ImageCategory] = None, include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]: ) -> PaginatedResults[ImageRecord]:
"""Gets a page of image records.""" """Gets a page of image records."""
pass pass
@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC):
# TODO: The database has a nullable `deleted_at` column, currently unused. # TODO: The database has a nullable `deleted_at` column, currently unused.
# Should we implement soft deletes? Would need coordination with ImageFileStorage. # Should we implement soft deletes? Would need coordination with ImageFileStorage.
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
"""Deletes an image record.""" """Deletes an image record."""
pass pass
@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
width: int, width: int,
height: int, height: int,
@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC):
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False, is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime: ) -> datetime:
"""Saves an image record.""" """Saves an image record."""
pass pass
@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
CREATE TABLE IF NOT EXISTS images ( CREATE TABLE IF NOT EXISTS images (
image_name TEXT NOT NULL PRIMARY KEY, image_name TEXT NOT NULL PRIMARY KEY,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_type TEXT NOT NULL, image_origin TEXT NOT NULL,
-- This is an enum in python, unrestricted string here for flexibility -- This is an enum in python, unrestricted string here for flexibility
image_category TEXT NOT NULL, image_category TEXT NOT NULL,
width INTEGER NOT NULL, width INTEGER NOT NULL,
@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id TEXT, session_id TEXT,
node_id TEXT, node_id TEXT,
metadata TEXT, metadata TEXT,
show_in_gallery BOOLEAN DEFAULT TRUE,
is_intermediate BOOLEAN DEFAULT FALSE, is_intermediate BOOLEAN DEFAULT FALSE,
created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')),
-- Updated via trigger -- Updated via trigger
@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
) )
self._cursor.execute( self._cursor.execute(
"""--sql """--sql
CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin);
""" """
) )
self._cursor.execute( self._cursor.execute(
@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
""" """
) )
def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: def get(
self, image_origin: ResourceOrigin, image_name: str
) -> Union[ImageRecord, None]:
try: try:
self._lock.acquire() self._lock.acquire()
@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def update( def update(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> None: ) -> None:
try: try:
@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self, self,
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
image_type: Optional[ImageType] = None, image_origin: Optional[ResourceOrigin] = None,
image_category: Optional[ImageCategory] = None, include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageRecord]: ) -> PaginatedResults[ImageRecord]:
try: try:
self._lock.acquire() self._lock.acquire()
# Manually build two queries - one for the count, one for the records # Manually build two queries - one for the count, one for the records
count_query = """--sql count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n"""
SELECT COUNT(*) FROM images WHERE 1=1 images_query = f"""SELECT * FROM images WHERE 1=1\n"""
"""
images_query = """--sql
SELECT * FROM images WHERE 1=1
"""
query_conditions = "" query_conditions = ""
query_params = [] query_params = []
if image_type is not None: if image_origin is not None:
query_conditions += """--sql query_conditions += f"""AND image_origin = ?\n"""
AND image_type = ? query_params.append(image_origin.value)
"""
query_params.append(image_type.value)
if image_category is not None: if include_categories is not None:
query_conditions += """--sql ## Convert the enum values to unique list of strings
AND image_category = ? include_category_strings = list(
""" map(lambda c: c.value, set(include_categories))
query_params.append(image_category.value) )
# Create the correct length of placeholders
placeholders = ",".join("?" * len(include_category_strings))
query_conditions += f"AND image_category IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*include_category_strings)
if exclude_categories is not None:
## Convert the enum values to unique list of strings
exclude_category_strings = list(
map(lambda c: c.value, set(exclude_categories))
)
# Create the correct length of placeholders
placeholders = ",".join("?" * len(exclude_category_strings))
query_conditions += f"AND image_category NOT IN ( {placeholders} )\n"
# Unpack the included categories into the query params
query_params.append(*exclude_category_strings)
if is_intermediate is not None: if is_intermediate is not None:
query_conditions += """--sql query_conditions += f"""AND is_intermediate = ?\n"""
AND is_intermediate = ?
"""
query_params.append(is_intermediate) query_params.append(is_intermediate)
if show_in_gallery is not None: query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n"""
query_conditions += """--sql
AND show_in_gallery = ?
"""
query_params.append(show_in_gallery)
query_pagination = """--sql
ORDER BY created_at DESC LIMIT ? OFFSET ?
"""
count_query += query_conditions + ";"
count_params = query_params.copy()
# Final images query with pagination
images_query += query_conditions + query_pagination + ";" images_query += query_conditions + query_pagination + ";"
# Add all the parameters
images_params = query_params.copy() images_params = query_params.copy()
images_params.append(per_page) images_params.append(per_page)
images_params.append(page * per_page) images_params.append(page * per_page)
# Build the list of images, deserializing each row
self._cursor.execute(images_query, images_params) self._cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], self._cursor.fetchall()) result = cast(list[sqlite3.Row], self._cursor.fetchall())
images = list(map(lambda r: deserialize_image_record(dict(r)), result)) images = list(map(lambda r: deserialize_image_record(dict(r)), result))
# Set up and execute the count query, without pagination
count_query += query_conditions + ";"
count_params = query_params.copy()
self._cursor.execute(count_query, count_params) self._cursor.execute(count_query, count_params)
count = self._cursor.fetchone()[0] count = self._cursor.fetchone()[0]
except sqlite3.Error as e: except sqlite3.Error as e:
self._conn.rollback() self._conn.rollback()
@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
items=images, page=page, pages=pageCount, per_page=per_page, total=count items=images, page=page, pages=pageCount, per_page=per_page, total=count
) )
def delete(self, image_type: ImageType, image_name: str) -> None: def delete(self, image_origin: ResourceOrigin, image_name: str) -> None:
try: try:
self._lock.acquire() self._lock.acquire()
self._cursor.execute( self._cursor.execute(
@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
def save( def save(
self, self,
image_name: str, image_name: str,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
session_id: Optional[str], session_id: Optional[str],
width: int, width: int,
@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str], node_id: Optional[str],
metadata: Optional[ImageMetadata], metadata: Optional[ImageMetadata],
is_intermediate: bool = False, is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> datetime: ) -> datetime:
try: try:
metadata_json = ( metadata_json = (
@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
"""--sql """--sql
INSERT OR IGNORE INTO images ( INSERT OR IGNORE INTO images (
image_name, image_name,
image_type, image_origin,
image_category, image_category,
width, width,
height, height,
node_id, node_id,
session_id, session_id,
metadata, metadata,
is_intermediate, is_intermediate
show_in_gallery
) )
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
""", """,
( (
image_name, image_name,
image_type.value, image_origin.value,
image_category.value, image_category.value,
width, width,
height, height,
@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
session_id, session_id,
metadata_json, metadata_json,
is_intermediate, is_intermediate,
show_in_gallery,
), ),
) )
self._conn.commit() self._conn.commit()

View File

@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType
from invokeai.app.models.image import ( from invokeai.app.models.image import (
ImageCategory, ImageCategory,
ImageType, ResourceOrigin,
InvalidImageCategoryException, InvalidImageCategoryException,
InvalidImageTypeException, InvalidOriginException,
) )
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.services.image_record_storage import ( from invokeai.app.services.image_record_storage import (
@ -44,12 +44,11 @@ class ImageServiceABC(ABC):
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
intermediate: bool = False, intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO: ) -> ImageDTO:
"""Creates an image, storing the file and its metadata.""" """Creates an image, storing the file and its metadata."""
pass pass
@ -57,7 +56,7 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def update( def update(
self, self,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
@ -65,22 +64,22 @@ class ImageServiceABC(ABC):
pass pass
@abstractmethod @abstractmethod
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
"""Gets an image as a PIL image.""" """Gets an image as a PIL image."""
pass pass
@abstractmethod @abstractmethod
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
"""Gets an image record.""" """Gets an image record."""
pass pass
@abstractmethod @abstractmethod
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
"""Gets an image DTO.""" """Gets an image DTO."""
pass pass
@abstractmethod @abstractmethod
def get_path(self, image_type: ImageType, image_name: str) -> str: def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str:
"""Gets an image's path.""" """Gets an image's path."""
pass pass
@ -91,7 +90,7 @@ class ImageServiceABC(ABC):
@abstractmethod @abstractmethod
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets an image's or thumbnail's URL.""" """Gets an image's or thumbnail's URL."""
pass pass
@ -101,16 +100,16 @@ class ImageServiceABC(ABC):
self, self,
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
image_type: Optional[ImageType] = None, image_origin: Optional[ResourceOrigin] = None,
image_category: Optional[ImageCategory] = None, include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]: ) -> PaginatedResults[ImageDTO]:
"""Gets a paginated list of image DTOs.""" """Gets a paginated list of image DTOs."""
pass pass
@abstractmethod @abstractmethod
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
"""Deletes an image.""" """Deletes an image."""
pass pass
@ -171,15 +170,14 @@ class ImageService(ImageServiceABC):
def create( def create(
self, self,
image: PILImageType, image: PILImageType,
image_type: ImageType, image_origin: ResourceOrigin,
image_category: ImageCategory, image_category: ImageCategory,
node_id: Optional[str] = None, node_id: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
is_intermediate: bool = False, is_intermediate: bool = False,
show_in_gallery: bool = True,
) -> ImageDTO: ) -> ImageDTO:
if image_type not in ImageType: if image_origin not in ResourceOrigin:
raise InvalidImageTypeException raise InvalidOriginException
if image_category not in ImageCategory: if image_category not in ImageCategory:
raise InvalidImageCategoryException raise InvalidImageCategoryException
@ -195,13 +193,12 @@ class ImageService(ImageServiceABC):
created_at = self._services.records.save( created_at = self._services.records.save(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
# Meta fields # Meta fields
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Nullable fields # Nullable fields
node_id=node_id, node_id=node_id,
session_id=session_id, session_id=session_id,
@ -209,21 +206,21 @@ class ImageService(ImageServiceABC):
) )
self._services.files.save( self._services.files.save(
image_type=image_type, image_origin=image_origin,
image_name=image_name, image_name=image_name,
image=image, image=image,
metadata=metadata, metadata=metadata,
) )
image_url = self._services.urls.get_image_url(image_type, image_name) image_url = self._services.urls.get_image_url(image_origin, image_name)
thumbnail_url = self._services.urls.get_image_url( thumbnail_url = self._services.urls.get_image_url(
image_type, image_name, True image_origin, image_name, True
) )
return ImageDTO( return ImageDTO(
# Non-nullable fields # Non-nullable fields
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -236,7 +233,6 @@ class ImageService(ImageServiceABC):
updated_at=created_at, # this is always the same as the created_at at this time updated_at=created_at, # this is always the same as the created_at at this time
deleted_at=None, deleted_at=None,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
# Extra non-nullable fields for DTO # Extra non-nullable fields for DTO
image_url=image_url, image_url=image_url,
thumbnail_url=thumbnail_url, thumbnail_url=thumbnail_url,
@ -253,13 +249,13 @@ class ImageService(ImageServiceABC):
def update( def update(
self, self,
image_type: ImageType, image_origin: ResourceOrigin,
image_name: str, image_name: str,
changes: ImageRecordChanges, changes: ImageRecordChanges,
) -> ImageDTO: ) -> ImageDTO:
try: try:
self._services.records.update(image_name, image_type, changes) self._services.records.update(image_name, image_origin, changes)
return self.get_dto(image_type, image_name) return self.get_dto(image_origin, image_name)
except ImageRecordSaveException: except ImageRecordSaveException:
self._services.logger.error("Failed to update image record") self._services.logger.error("Failed to update image record")
raise raise
@ -267,9 +263,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem updating image record") self._services.logger.error("Problem updating image record")
raise e raise e
def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType:
try: try:
return self._services.files.get(image_type, image_name) return self._services.files.get(image_origin, image_name)
except ImageFileNotFoundException: except ImageFileNotFoundException:
self._services.logger.error("Failed to get image file") self._services.logger.error("Failed to get image file")
raise raise
@ -277,9 +273,9 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image file") self._services.logger.error("Problem getting image file")
raise e raise e
def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord:
try: try:
return self._services.records.get(image_type, image_name) return self._services.records.get(image_origin, image_name)
except ImageRecordNotFoundException: except ImageRecordNotFoundException:
self._services.logger.error("Image record not found") self._services.logger.error("Image record not found")
raise raise
@ -287,14 +283,14 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting image record") self._services.logger.error("Problem getting image record")
raise e raise e
def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO:
try: try:
image_record = self._services.records.get(image_type, image_name) image_record = self._services.records.get(image_origin, image_name)
image_dto = image_record_to_dto( image_dto = image_record_to_dto(
image_record, image_record,
self._services.urls.get_image_url(image_type, image_name), self._services.urls.get_image_url(image_origin, image_name),
self._services.urls.get_image_url(image_type, image_name, True), self._services.urls.get_image_url(image_origin, image_name, True),
) )
return image_dto return image_dto
@ -306,10 +302,10 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_path( def get_path(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.files.get_path(image_type, image_name, thumbnail) return self._services.files.get_path(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -322,10 +318,10 @@ class ImageService(ImageServiceABC):
raise e raise e
def get_url( def get_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
try: try:
return self._services.urls.get_image_url(image_type, image_name, thumbnail) return self._services.urls.get_image_url(image_origin, image_name, thumbnail)
except Exception as e: except Exception as e:
self._services.logger.error("Problem getting image path") self._services.logger.error("Problem getting image path")
raise e raise e
@ -334,28 +330,28 @@ class ImageService(ImageServiceABC):
self, self,
page: int = 0, page: int = 0,
per_page: int = 10, per_page: int = 10,
image_type: Optional[ImageType] = None, image_origin: Optional[ResourceOrigin] = None,
image_category: Optional[ImageCategory] = None, include_categories: Optional[list[ImageCategory]] = None,
exclude_categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None, is_intermediate: Optional[bool] = None,
show_in_gallery: Optional[bool] = None,
) -> PaginatedResults[ImageDTO]: ) -> PaginatedResults[ImageDTO]:
try: try:
results = self._services.records.get_many( results = self._services.records.get_many(
page, page,
per_page, per_page,
image_type, image_origin,
image_category, include_categories,
exclude_categories,
is_intermediate, is_intermediate,
show_in_gallery,
) )
image_dtos = list( image_dtos = list(
map( map(
lambda r: image_record_to_dto( lambda r: image_record_to_dto(
r, r,
self._services.urls.get_image_url(r.image_type, r.image_name), self._services.urls.get_image_url(r.image_origin, r.image_name),
self._services.urls.get_image_url( self._services.urls.get_image_url(
r.image_type, r.image_name, True r.image_origin, r.image_name, True
), ),
), ),
results.items, results.items,
@ -373,10 +369,10 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Problem getting paginated image DTOs") self._services.logger.error("Problem getting paginated image DTOs")
raise e raise e
def delete(self, image_type: ImageType, image_name: str): def delete(self, image_origin: ResourceOrigin, image_name: str):
try: try:
self._services.files.delete(image_type, image_name) self._services.files.delete(image_origin, image_name)
self._services.records.delete(image_type, image_name) self._services.records.delete(image_origin, image_name)
except ImageRecordDeleteException: except ImageRecordDeleteException:
self._services.logger.error(f"Failed to delete image record") self._services.logger.error(f"Failed to delete image record")
raise raise

View File

@ -1,7 +1,7 @@
import datetime import datetime
from typing import Optional, Union from typing import Optional, Union
from pydantic import BaseModel, Extra, Field, StrictStr from pydantic import BaseModel, Extra, Field, StrictStr
from invokeai.app.models.image import ImageCategory, ImageType from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.metadata import ImageMetadata
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
@ -11,8 +11,8 @@ class ImageRecord(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_category: ImageCategory = Field(description="The category of the image.") image_category: ImageCategory = Field(description="The category of the image.")
"""The category of the image.""" """The category of the image."""
width: int = Field(description="The width of the image in px.") width: int = Field(description="The width of the image in px.")
@ -33,8 +33,6 @@ class ImageRecord(BaseModel):
"""The deleted timestamp of the image.""" """The deleted timestamp of the image."""
is_intermediate: bool = Field(description="Whether this is an intermediate image.") is_intermediate: bool = Field(description="Whether this is an intermediate image.")
"""Whether this is an intermediate image.""" """Whether this is an intermediate image."""
show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.")
"""Whether this image should be shown in the gallery."""
session_id: Optional[str] = Field( session_id: Optional[str] = Field(
default=None, default=None,
description="The session ID that generated this image, if it is a generated image.", description="The session ID that generated this image, if it is a generated image.",
@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel):
image_name: str = Field(description="The unique name of the image.") image_name: str = Field(description="The unique name of the image.")
"""The unique name of the image.""" """The unique name of the image."""
image_type: ImageType = Field(description="The type of the image.") image_origin: ResourceOrigin = Field(description="The type of the image.")
"""The type of the image.""" """The origin of the image."""
image_url: str = Field(description="The URL of the image.") image_url: str = Field(description="The URL of the image.")
"""The URL of the image.""" """The URL of the image."""
thumbnail_url: str = Field(description="The URL of the image's thumbnail.") thumbnail_url: str = Field(description="The URL of the image's thumbnail.")
@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
# Retrieve all the values, setting "reasonable" defaults if they are not present. # Retrieve all the values, setting "reasonable" defaults if they are not present.
image_name = image_dict.get("image_name", "unknown") image_name = image_dict.get("image_name", "unknown")
image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value))
image_category = ImageCategory( image_category = ImageCategory(
image_dict.get("image_category", ImageCategory.GENERAL.value) image_dict.get("image_category", ImageCategory.GENERAL.value)
) )
@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at = image_dict.get("updated_at", get_iso_timestamp()) updated_at = image_dict.get("updated_at", get_iso_timestamp())
deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp())
is_intermediate = image_dict.get("is_intermediate", False) is_intermediate = image_dict.get("is_intermediate", False)
show_in_gallery = image_dict.get("show_in_gallery", True)
raw_metadata = image_dict.get("metadata") raw_metadata = image_dict.get("metadata")
@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
return ImageRecord( return ImageRecord(
image_name=image_name, image_name=image_name,
image_type=image_type, image_origin=image_origin,
image_category=image_category, image_category=image_category,
width=width, width=width,
height=height, height=height,
@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
updated_at=updated_at, updated_at=updated_at,
deleted_at=deleted_at, deleted_at=deleted_at,
is_intermediate=is_intermediate, is_intermediate=is_intermediate,
show_in_gallery=show_in_gallery,
) )

View File

@ -1,7 +1,7 @@
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from invokeai.app.models.image import ImageType from invokeai.app.models.image import ResourceOrigin
from invokeai.app.util.thumbnails import get_thumbnail_name from invokeai.app.util.thumbnails import get_thumbnail_name
@ -10,7 +10,7 @@ class UrlServiceBase(ABC):
@abstractmethod @abstractmethod
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
"""Gets the URL for an image or thumbnail.""" """Gets the URL for an image or thumbnail."""
pass pass
@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase):
self._base_url = base_url self._base_url = base_url
def get_image_url( def get_image_url(
self, image_type: ImageType, image_name: str, thumbnail: bool = False self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False
) -> str: ) -> str:
image_basename = os.path.basename(image_name) image_basename = os.path.basename(image_name)
# These paths are determined by the routes in invokeai/app/api/routers/images.py # These paths are determined by the routes in invokeai/app/api/routers/images.py
if thumbnail: if thumbnail:
return ( return (
f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail"
) )
return f"{self._base_url}/images/{image_type.value}/{image_basename}" return f"{self._base_url}/images/{image_origin.value}/{image_basename}"

View File

@ -1,5 +1,5 @@
from invokeai.app.api.models.images import ProgressImage
from invokeai.app.models.exceptions import CanceledException from invokeai.app.models.exceptions import CanceledException
from invokeai.app.models.image import ProgressImage
from ..invocations.baseinvocation import InvocationContext from ..invocations.baseinvocation import InvocationContext
from ...backend.util.util import image_to_dataURL from ...backend.util.util import image_to_dataURL
from ...backend.generator.base import Generator from ...backend.generator.base import Generator