diff --git a/invokeai/app/api/models/images.py b/invokeai/app/api/models/images.py deleted file mode 100644 index fa04702326..0000000000 --- a/invokeai/app/api/models/images.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional -from pydantic import BaseModel, Field - -from invokeai.app.models.image import ImageType - - -class ImageResponseMetadata(BaseModel): - """An image's metadata. Used only in HTTP responses.""" - - created: int = Field(description="The creation timestamp of the image") - width: int = Field(description="The width of the image in pixels") - height: int = Field(description="The height of the image in pixels") - # invokeai: Optional[InvokeAIMetadata] = Field( - # description="The image's InvokeAI-specific metadata" - # ) - - -class ImageResponse(BaseModel): - """The response type for images""" - - image_type: ImageType = Field(description="The type of the image") - image_name: str = Field(description="The name of the image") - image_url: str = Field(description="The url of the image") - thumbnail_url: str = Field(description="The url of the image's thumbnail") - metadata: ImageResponseMetadata = Field(description="The image's metadata") - - -class ProgressImage(BaseModel): - """The progress image sent intermittently during processing""" - - width: int = Field(description="The effective width of the image in pixels") - height: int = Field(description="The effective height of the image in pixels") - dataURL: str = Field(description="The image data as a b64 data URL") - - -class SavedImage(BaseModel): - image_name: str = Field(description="The name of the saved image") - thumbnail_name: str = Field(description="The name of the saved thumbnail") - created: int = Field(description="The created timestamp of the saved image") diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 55556dd79a..f0399a2d07 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -6,7 +6,7 @@ from fastapi.responses import FileResponse from PIL import Image from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, ) from invokeai.app.services.models.image_record import ( ImageDTO, @@ -36,9 +36,6 @@ async def upload_image( response: Response, image_category: ImageCategory = Query(description="The category of the image"), is_intermediate: bool = Query(description="Whether this is an intermediate image"), - show_in_gallery: bool = Query( - description="Whether this image should be shown in the gallery" - ), session_id: Optional[str] = Query( default=None, description="The session ID associated with this upload, if any" ), @@ -58,11 +55,10 @@ async def upload_image( try: image_dto = ApiDependencies.invoker.services.images.create( image=pil_image, - image_type=ImageType.UPLOAD, + image_origin=ResourceOrigin.EXTERNAL, image_category=image_category, session_id=session_id, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, ) response.status_code = 201 @@ -73,27 +69,27 @@ async def upload_image( raise HTTPException(status_code=500, detail="Failed to create image") -@images_router.delete("/{image_type}/{image_name}", operation_id="delete_image") +@images_router.delete("/{image_origin}/{image_name}", operation_id="delete_image") async def delete_image( - image_type: ImageType = Path(description="The type of image to delete"), + image_origin: ResourceOrigin = Path(description="The origin of image to delete"), image_name: str = Path(description="The name of the image to delete"), ) -> None: """Deletes an image""" try: - ApiDependencies.invoker.services.images.delete(image_type, image_name) + ApiDependencies.invoker.services.images.delete(image_origin, image_name) except Exception as e: # TODO: Does this need any exception handling at all? pass @images_router.patch( - "/{image_type}/{image_name}", + "/{image_origin}/{image_name}", operation_id="update_image", response_model=ImageDTO, ) async def update_image( - image_type: ImageType = Path(description="The type of image to update"), + image_origin: ResourceOrigin = Path(description="The origin of image to update"), image_name: str = Path(description="The name of the image to update"), image_changes: ImageRecordChanges = Body( description="The changes to apply to the image" @@ -103,31 +99,31 @@ async def update_image( try: return ApiDependencies.invoker.services.images.update( - image_type, image_name, image_changes + image_origin, image_name, image_changes ) except Exception as e: raise HTTPException(status_code=400, detail="Failed to update image") @images_router.get( - "/{image_type}/{image_name}/metadata", + "/{image_origin}/{image_name}/metadata", operation_id="get_image_metadata", response_model=ImageDTO, ) async def get_image_metadata( - image_type: ImageType = Path(description="The type of image to get"), + image_origin: ResourceOrigin = Path(description="The origin of image to get"), image_name: str = Path(description="The name of image to get"), ) -> ImageDTO: """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images.get_dto(image_type, image_name) + return ApiDependencies.invoker.services.images.get_dto(image_origin, image_name) except Exception as e: raise HTTPException(status_code=404) @images_router.get( - "/{image_type}/{image_name}", + "/{image_origin}/{image_name}", operation_id="get_image_full", response_class=Response, responses={ @@ -139,7 +135,7 @@ async def get_image_metadata( }, ) async def get_image_full( - image_type: ImageType = Path( + image_origin: ResourceOrigin = Path( description="The type of full-resolution image file to get" ), image_name: str = Path(description="The name of full-resolution image file to get"), @@ -147,7 +143,7 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images.get_path(image_type, image_name) + path = ApiDependencies.invoker.services.images.get_path(image_origin, image_name) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -163,7 +159,7 @@ async def get_image_full( @images_router.get( - "/{image_type}/{image_name}/thumbnail", + "/{image_origin}/{image_name}/thumbnail", operation_id="get_image_thumbnail", response_class=Response, responses={ @@ -175,14 +171,14 @@ async def get_image_full( }, ) async def get_image_thumbnail( - image_type: ImageType = Path(description="The type of thumbnail image file to get"), + image_origin: ResourceOrigin = Path(description="The origin of thumbnail image file to get"), image_name: str = Path(description="The name of thumbnail image file to get"), ) -> FileResponse: """Gets a thumbnail image file""" try: path = ApiDependencies.invoker.services.images.get_path( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) @@ -195,25 +191,25 @@ async def get_image_thumbnail( @images_router.get( - "/{image_type}/{image_name}/urls", + "/{image_origin}/{image_name}/urls", operation_id="get_image_urls", response_model=ImageUrlsDTO, ) async def get_image_urls( - image_type: ImageType = Path(description="The type of the image whose URL to get"), + image_origin: ResourceOrigin = Path(description="The origin of the image whose URL to get"), image_name: str = Path(description="The name of the image whose URL to get"), ) -> ImageUrlsDTO: """Gets an image and thumbnail URL""" try: image_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name + image_origin, image_name ) thumbnail_url = ApiDependencies.invoker.services.images.get_url( - image_type, image_name, thumbnail=True + image_origin, image_name, thumbnail=True ) return ImageUrlsDTO( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image_url=image_url, thumbnail_url=thumbnail_url, @@ -228,30 +224,33 @@ async def get_image_urls( response_model=PaginatedResults[ImageDTO], ) async def list_images_with_metadata( - image_type: Optional[ImageType] = Query( - default=None, description="The type of images to list" + image_origin: Optional[ResourceOrigin] = Query( + default=None, description="The origin of images to list" ), - image_category: Optional[ImageCategory] = Query( - default=None, description="The kind of images to list" + include_categories: Optional[list[ImageCategory]] = Query( + default=None, description="The categories of image to include" + ), + exclude_categories: Optional[list[ImageCategory]] = Query( + default=None, description="The categories of image to exclude" ), is_intermediate: Optional[bool] = Query( default=None, description="Whether to list intermediate images" ), - show_in_gallery: Optional[bool] = Query( - default=None, description="Whether to list images that show in the gallery" - ), page: int = Query(default=0, description="The page of images to get"), per_page: int = Query(default=10, description="The number of images per page"), ) -> PaginatedResults[ImageDTO]: """Gets a list of images""" + if include_categories is not None and exclude_categories is not None: + raise HTTPException(status_code=400, detail="Cannot use both 'include_category' and 'exclude_category' at the same time.") + image_dtos = ApiDependencies.invoker.services.images.get_many( page, per_page, - image_type, - image_category, + image_origin, + include_categories, + exclude_categories, is_intermediate, - show_in_gallery, ) return image_dtos diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5e9fe088b5..5275116a2a 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageOps from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -37,10 +37,10 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) # Convert to cv image/mask @@ -57,7 +57,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): image_dto = context.services.images.create( image=image_inpainted, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -67,7 +67,7 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py index 0385c6a9f0..d2ce59d247 100644 --- a/invokeai/app/invocations/generate.py +++ b/invokeai/app/invocations/generate.py @@ -10,9 +10,9 @@ import torch from pydantic import BaseModel, Field -from invokeai.app.models.image import ColorField, ImageField, ImageType +from invokeai.app.models.image import ColorField, ImageField, ResourceOrigin from invokeai.app.invocations.util.choose_model import choose_model -from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.generator.inpaint import infill_methods from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig @@ -120,7 +120,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): image_dto = context.services.images.create( image=generate_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -130,7 +130,7 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -170,7 +170,7 @@ class ImageToImageInvocation(TextToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) @@ -201,7 +201,7 @@ class ImageToImageInvocation(TextToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -211,7 +211,7 @@ class ImageToImageInvocation(TextToImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -283,13 +283,13 @@ class InpaintInvocation(ImageToImageInvocation): None if self.image is None else context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) ) mask = ( None if self.mask is None - else context.services.images.get_pil_image(self.mask.image_type, self.mask.image_name) + else context.services.images.get_pil_image(self.mask.image_origin, self.mask.image_name) ) # Handle invalid model parameter @@ -317,7 +317,7 @@ class InpaintInvocation(ImageToImageInvocation): image_dto = context.services.images.create( image=generator_output.image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -327,7 +327,7 @@ class InpaintInvocation(ImageToImageInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 69d51e6158..7633bfbc16 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -7,7 +7,7 @@ import numpy from PIL import Image, ImageFilter, ImageOps, ImageChops from pydantic import BaseModel, Field -from ..models.image import ImageCategory, ImageField, ImageType +from ..models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -72,12 +72,12 @@ class LoadImageInvocation(BaseInvocation): ) # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) + image = context.services.images.get_pil_image(self.image.image_origin, self.image.image_name) return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -96,7 +96,7 @@ class ShowImageInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if image: image.show() @@ -106,7 +106,7 @@ class ShowImageInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=self.image.image_name, - image_type=self.image.image_type, + image_origin=self.image.image_origin, ), width=image.width, height=image.height, @@ -129,7 +129,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_crop = Image.new( @@ -139,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_crop, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -149,7 +149,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -172,17 +172,17 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: base_image = context.services.images.get_pil_image( - self.base_image.image_type, self.base_image.image_name + self.base_image.image_origin, self.base_image.image_name ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) mask = ( None if self.mask is None else ImageOps.invert( context.services.images.get_pil_image( - self.mask.image_type, self.mask.image_name + self.mask.image_origin, self.mask.image_name ) ) ) @@ -201,7 +201,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=new_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -211,7 +211,7 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -231,7 +231,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> MaskOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_mask = image.split()[-1] @@ -240,7 +240,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_mask, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, @@ -249,7 +249,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): return MaskOutput( mask=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -269,17 +269,17 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image1 = context.services.images.get_pil_image( - self.image1.image_type, self.image1.image_name + self.image1.image_origin, self.image1.image_name ) image2 = context.services.images.get_pil_image( - self.image2.image_type, self.image2.image_name + self.image2.image_origin, self.image2.image_name ) multiply_image = ImageChops.multiply(image1, image2) image_dto = context.services.images.create( image=multiply_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -288,7 +288,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -311,14 +311,14 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) channel_image = image.getchannel(self.channel) image_dto = context.services.images.create( image=channel_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -327,7 +327,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -350,14 +350,14 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) converted_image = image.convert(self.mode) image_dto = context.services.images.create( image=converted_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -366,7 +366,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( - image_type=image_dto.image_type, image_name=image_dto.image_name + image_origin=image_dto.image_origin, image_name=image_dto.image_name ), width=image_dto.width, height=image_dto.height, @@ -387,7 +387,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) blur = ( @@ -399,7 +399,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=blur_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -409,7 +409,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -430,7 +430,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) / 255 @@ -440,7 +440,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=lerp_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -450,7 +450,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -471,7 +471,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) image_arr = numpy.asarray(image, dtype=numpy.float32) @@ -486,7 +486,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=ilerp_image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -496,7 +496,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index ad60b62633..a06780c1f5 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -11,7 +11,7 @@ from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch -from ..models.image import ColorField, ImageCategory, ImageField, ImageType +from ..models.image import ColorField, ImageCategory, ImageField, ResourceOrigin from .baseinvocation import ( BaseInvocation, InvocationContext, @@ -135,7 +135,7 @@ class InfillColorInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) @@ -145,7 +145,7 @@ class InfillColorInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -155,7 +155,7 @@ class InfillColorInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -180,7 +180,7 @@ class InfillTileInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) infilled = tile_fill_missing( @@ -190,7 +190,7 @@ class InfillTileInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -200,7 +200,7 @@ class InfillTileInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -218,7 +218,7 @@ class InfillPatchMatchInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) if PatchMatch.patchmatch_available(): @@ -228,7 +228,7 @@ class InfillPatchMatchInvocation(BaseInvocation): image_dto = context.services.images.create( image=infilled, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -238,7 +238,7 @@ class InfillPatchMatchInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 4975b7b578..7085cfd308 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -28,7 +28,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig import numpy as np -from ..services.image_file_storage import ImageType +from ..services.image_file_storage import ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext from .image import ImageField, ImageOutput from .compel import ConditioningField @@ -468,7 +468,7 @@ class LatentsToImageInvocation(BaseInvocation): # and gnenerate unique image_name image_dto = context.services.images.create( image=image, - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, session_id=context.graph_execution_state_id, node_id=self.id, @@ -478,7 +478,7 @@ class LatentsToImageInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, @@ -576,7 +576,7 @@ class ImageToLatentsInvocation(BaseInvocation): # self.image.image_type, self.image.image_name # ) image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) # TODO: this only really needs the vae diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index db71e4201d..5313411400 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,7 +2,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -29,7 +29,7 @@ class RestoreFaceInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -43,7 +43,7 @@ class RestoreFaceInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -53,7 +53,7 @@ class RestoreFaceInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index 90c9e4bf4f..80e1567047 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,7 +4,7 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageCategory, ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig from .image import ImageOutput @@ -31,7 +31,7 @@ class UpscaleInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image( - self.image.image_type, self.image.image_name + self.image.image_origin, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( image_list=[[image, 0]], @@ -45,7 +45,7 @@ class UpscaleInvocation(BaseInvocation): # TODO: can this return multiple results? image_dto = context.services.images.create( image=results[0][0], - image_type=ImageType.RESULT, + image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -55,7 +55,7 @@ class UpscaleInvocation(BaseInvocation): return ImageOutput( image=ImageField( image_name=image_dto.image_name, - image_type=image_dto.image_type, + image_origin=image_dto.image_origin, ), width=image_dto.width, height=image_dto.height, diff --git a/invokeai/app/models/image.py b/invokeai/app/models/image.py index 46b50145aa..6d48f2dbb1 100644 --- a/invokeai/app/models/image.py +++ b/invokeai/app/models/image.py @@ -5,30 +5,52 @@ from pydantic import BaseModel, Field from invokeai.app.util.metaenum import MetaEnum -class ImageType(str, Enum, metaclass=MetaEnum): - """The type of an image.""" +class ResourceOrigin(str, Enum, metaclass=MetaEnum): + """The origin of a resource (eg image). - RESULT = "results" - UPLOAD = "uploads" + - INTERNAL: The resource was created by the application. + - EXTERNAL: The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ + + INTERNAL = "internal" + """The resource was created by the application.""" + EXTERNAL = "external" + """The resource was not created by the application. + This may be a user-initiated upload, or an internal application upload (eg Canvas init image). + """ -class InvalidImageTypeException(ValueError): - """Raised when a provided value is not a valid ImageType. +class InvalidOriginException(ValueError): + """Raised when a provided value is not a valid ResourceOrigin. Subclasses `ValueError`. """ - def __init__(self, message="Invalid image type."): + def __init__(self, message="Invalid resource origin."): super().__init__(message) class ImageCategory(str, Enum, metaclass=MetaEnum): - """The category of an image. Use ImageCategory.OTHER for non-default categories.""" + """The category of an image. + + - GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose. + - MASK: The image is a mask image. + - CONTROL: The image is a ControlNet control image. + - USER: The image is a user-provide image. + - OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes. + """ GENERAL = "general" - CONTROL = "control" + """GENERAL: The image is an output, init image, or otherwise an image without a specialized purpose.""" MASK = "mask" + """MASK: The image is a mask image.""" + CONTROL = "control" + """CONTROL: The image is a ControlNet control image.""" + USER = "user" + """USER: The image is a user-provide image.""" OTHER = "other" + """OTHER: The image is some other type of image with a specialized purpose. To be used by external nodes.""" class InvalidImageCategoryException(ValueError): @@ -44,13 +66,13 @@ class InvalidImageCategoryException(ValueError): class ImageField(BaseModel): """An image field used for passing image objects between invocations""" - image_type: ImageType = Field( - default=ImageType.RESULT, description="The type of the image" + image_origin: ResourceOrigin = Field( + default=ResourceOrigin.INTERNAL, description="The type of the image" ) image_name: Optional[str] = Field(default=None, description="The name of the image") class Config: - schema_extra = {"required": ["image_type", "image_name"]} + schema_extra = {"required": ["image_origin", "image_name"]} class ColorField(BaseModel): @@ -61,3 +83,11 @@ class ColorField(BaseModel): def tuple(self) -> Tuple[int, int, int, int]: return (self.r, self.g, self.b, self.a) + + +class ProgressImage(BaseModel): + """The progress image sent intermittently during processing""" + + width: int = Field(description="The effective width of the image in pixels") + height: int = Field(description="The effective height of the image in pixels") + dataURL: str = Field(description="The image data as a b64 data URL") diff --git a/invokeai/app/services/events.py b/invokeai/app/services/events.py index a3e7cdd5dc..788f24dbce 100644 --- a/invokeai/app/services/events.py +++ b/invokeai/app/services/events.py @@ -1,7 +1,7 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from typing import Any, Optional -from invokeai.app.api.models.images import ProgressImage +from typing import Any +from invokeai.app.models.image import ProgressImage from invokeai.app.util.misc import get_timestamp diff --git a/invokeai/app/services/image_file_storage.py b/invokeai/app/services/image_file_storage.py index 46070b3bf2..68a994ea75 100644 --- a/invokeai/app/services/image_file_storage.py +++ b/invokeai/app/services/image_file_storage.py @@ -9,7 +9,7 @@ from PIL.Image import Image as PILImageType from PIL import Image, PngImagePlugin from send2trash import send2trash -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.thumbnails import get_thumbnail_name, make_thumbnail @@ -40,13 +40,13 @@ class ImageFileStorageBase(ABC): """Low-level service responsible for storing and retrieving image files.""" @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Retrieves an image as PIL Image.""" pass @abstractmethod def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the internal path to an image or thumbnail.""" pass @@ -62,7 +62,7 @@ class ImageFileStorageBase(ABC): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, @@ -71,7 +71,7 @@ class ImageFileStorageBase(ABC): pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image and its thumbnail (if one exists).""" pass @@ -93,17 +93,17 @@ class DiskImageFileStorage(ImageFileStorageBase): Path(output_folder).mkdir(parents=True, exist_ok=True) # TODO: don't hard-code. get/save/delete should maybe take subpath? - for image_type in ImageType: - Path(os.path.join(output_folder, image_type)).mkdir( + for image_origin in ResourceOrigin: + Path(os.path.join(output_folder, image_origin)).mkdir( parents=True, exist_ok=True ) - Path(os.path.join(output_folder, image_type, "thumbnails")).mkdir( + Path(os.path.join(output_folder, image_origin, "thumbnails")).mkdir( parents=True, exist_ok=True ) - def get(self, image_type: ImageType, image_name: str) -> PILImageType: + def get(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) cache_item = self.__get_cache(image_path) if cache_item: return cache_item @@ -117,13 +117,13 @@ class DiskImageFileStorage(ImageFileStorageBase): def save( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, metadata: Optional[ImageMetadata] = None, thumbnail_size: int = 256, ) -> None: try: - image_path = self.get_path(image_type, image_name) + image_path = self.get_path(image_origin, image_name) if metadata is not None: pnginfo = PngImagePlugin.PngInfo() @@ -133,7 +133,7 @@ class DiskImageFileStorage(ImageFileStorageBase): image.save(image_path, "PNG") thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, thumbnail=True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, thumbnail=True) thumbnail_image = make_thumbnail(image, thumbnail_size) thumbnail_image.save(thumbnail_path) @@ -142,10 +142,10 @@ class DiskImageFileStorage(ImageFileStorageBase): except Exception as e: raise ImageFileSaveException from e - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: basename = os.path.basename(image_name) - image_path = self.get_path(image_type, basename) + image_path = self.get_path(image_origin, basename) if os.path.exists(image_path): send2trash(image_path) @@ -153,7 +153,7 @@ class DiskImageFileStorage(ImageFileStorageBase): del self.__cache[image_path] thumbnail_name = get_thumbnail_name(image_name) - thumbnail_path = self.get_path(image_type, thumbnail_name, True) + thumbnail_path = self.get_path(image_origin, thumbnail_name, True) if os.path.exists(thumbnail_path): send2trash(thumbnail_path) @@ -164,7 +164,7 @@ class DiskImageFileStorage(ImageFileStorageBase): # TODO: make this a bit more flexible for e.g. cloud storage def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: # strip out any relative path shenanigans basename = os.path.basename(image_name) @@ -172,10 +172,10 @@ class DiskImageFileStorage(ImageFileStorageBase): if thumbnail: thumbnail_name = get_thumbnail_name(basename) path = os.path.join( - self.__output_folder, image_type, "thumbnails", thumbnail_name + self.__output_folder, image_origin, "thumbnails", thumbnail_name ) else: - path = os.path.join(self.__output_folder, image_type, basename) + path = os.path.join(self.__output_folder, image_origin, basename) abspath = os.path.abspath(path) diff --git a/invokeai/app/services/image_record_storage.py b/invokeai/app/services/image_record_storage.py index 8afa7000fb..6b6d1ce7b2 100644 --- a/invokeai/app/services/image_record_storage.py +++ b/invokeai/app/services/image_record_storage.py @@ -8,7 +8,7 @@ from typing import Optional, Union from invokeai.app.models.metadata import ImageMetadata from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, ) from invokeai.app.services.models.image_record import ( ImageRecord, @@ -46,7 +46,7 @@ class ImageRecordStorageBase(ABC): # TODO: Implement an `update()` method @abstractmethod - def get(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass @@ -54,7 +54,7 @@ class ImageRecordStorageBase(ABC): def update( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: """Updates an image record.""" @@ -65,10 +65,10 @@ class ImageRecordStorageBase(ABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: """Gets a page of image records.""" pass @@ -76,7 +76,7 @@ class ImageRecordStorageBase(ABC): # TODO: The database has a nullable `deleted_at` column, currently unused. # Should we implement soft deletes? Would need coordination with ImageFileStorage. @abstractmethod - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: """Deletes an image record.""" pass @@ -84,7 +84,7 @@ class ImageRecordStorageBase(ABC): def save( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, width: int, height: int, @@ -92,7 +92,6 @@ class ImageRecordStorageBase(ABC): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> datetime: """Saves an image record.""" pass @@ -131,7 +130,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): CREATE TABLE IF NOT EXISTS images ( image_name TEXT NOT NULL PRIMARY KEY, -- This is an enum in python, unrestricted string here for flexibility - image_type TEXT NOT NULL, + image_origin TEXT NOT NULL, -- This is an enum in python, unrestricted string here for flexibility image_category TEXT NOT NULL, width INTEGER NOT NULL, @@ -139,7 +138,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id TEXT, node_id TEXT, metadata TEXT, - show_in_gallery BOOLEAN DEFAULT TRUE, is_intermediate BOOLEAN DEFAULT FALSE, created_at DATETIME NOT NULL DEFAULT(STRFTIME('%Y-%m-%d %H:%M:%f', 'NOW')), -- Updated via trigger @@ -158,7 +156,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): ) self._cursor.execute( """--sql - CREATE INDEX IF NOT EXISTS idx_images_image_type ON images(image_type); + CREATE INDEX IF NOT EXISTS idx_images_image_origin ON images(image_origin); """ ) self._cursor.execute( @@ -185,7 +183,9 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """ ) - def get(self, image_type: ImageType, image_name: str) -> Union[ImageRecord, None]: + def get( + self, image_origin: ResourceOrigin, image_name: str + ) -> Union[ImageRecord, None]: try: self._lock.acquire() @@ -212,7 +212,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def update( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, changes: ImageRecordChanges, ) -> None: try: @@ -249,71 +249,72 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageRecord]: try: self._lock.acquire() # Manually build two queries - one for the count, one for the records - count_query = """--sql - SELECT COUNT(*) FROM images WHERE 1=1 - """ - - images_query = """--sql - SELECT * FROM images WHERE 1=1 - """ + count_query = f"""SELECT COUNT(*) FROM images WHERE 1=1\n""" + images_query = f"""SELECT * FROM images WHERE 1=1\n""" query_conditions = "" query_params = [] - if image_type is not None: - query_conditions += """--sql - AND image_type = ? - """ - query_params.append(image_type.value) + if image_origin is not None: + query_conditions += f"""AND image_origin = ?\n""" + query_params.append(image_origin.value) - if image_category is not None: - query_conditions += """--sql - AND image_category = ? - """ - query_params.append(image_category.value) + if include_categories is not None: + ## Convert the enum values to unique list of strings + include_category_strings = list( + map(lambda c: c.value, set(include_categories)) + ) + # Create the correct length of placeholders + placeholders = ",".join("?" * len(include_category_strings)) + query_conditions += f"AND image_category IN ( {placeholders} )\n" + + # Unpack the included categories into the query params + query_params.append(*include_category_strings) + + if exclude_categories is not None: + ## Convert the enum values to unique list of strings + exclude_category_strings = list( + map(lambda c: c.value, set(exclude_categories)) + ) + + # Create the correct length of placeholders + placeholders = ",".join("?" * len(exclude_category_strings)) + query_conditions += f"AND image_category NOT IN ( {placeholders} )\n" + + # Unpack the included categories into the query params + query_params.append(*exclude_category_strings) if is_intermediate is not None: - query_conditions += """--sql - AND is_intermediate = ? - """ + query_conditions += f"""AND is_intermediate = ?\n""" query_params.append(is_intermediate) - if show_in_gallery is not None: - query_conditions += """--sql - AND show_in_gallery = ? - """ - query_params.append(show_in_gallery) - - query_pagination = """--sql - ORDER BY created_at DESC LIMIT ? OFFSET ? - """ - - count_query += query_conditions + ";" - count_params = query_params.copy() + query_pagination = f"""ORDER BY created_at DESC LIMIT ? OFFSET ?\n""" + # Final images query with pagination images_query += query_conditions + query_pagination + ";" + # Add all the parameters images_params = query_params.copy() images_params.append(per_page) images_params.append(page * per_page) - + # Build the list of images, deserializing each row self._cursor.execute(images_query, images_params) - result = cast(list[sqlite3.Row], self._cursor.fetchall()) - images = list(map(lambda r: deserialize_image_record(dict(r)), result)) + # Set up and execute the count query, without pagination + count_query += query_conditions + ";" + count_params = query_params.copy() self._cursor.execute(count_query, count_params) - count = self._cursor.fetchone()[0] except sqlite3.Error as e: self._conn.rollback() @@ -327,7 +328,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): items=images, page=page, pages=pageCount, per_page=per_page, total=count ) - def delete(self, image_type: ImageType, image_name: str) -> None: + def delete(self, image_origin: ResourceOrigin, image_name: str) -> None: try: self._lock.acquire() self._cursor.execute( @@ -347,7 +348,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): def save( self, image_name: str, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, session_id: Optional[str], width: int, @@ -355,7 +356,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): node_id: Optional[str], metadata: Optional[ImageMetadata], is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> datetime: try: metadata_json = ( @@ -366,21 +366,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): """--sql INSERT OR IGNORE INTO images ( image_name, - image_type, + image_origin, image_category, width, height, node_id, session_id, metadata, - is_intermediate, - show_in_gallery + is_intermediate ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?); + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?); """, ( image_name, - image_type.value, + image_origin.value, image_category.value, width, height, @@ -388,7 +387,6 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): session_id, metadata_json, is_intermediate, - show_in_gallery, ), ) self._conn.commit() diff --git a/invokeai/app/services/images.py b/invokeai/app/services/images.py index 1bde1acfd4..dca95f673f 100644 --- a/invokeai/app/services/images.py +++ b/invokeai/app/services/images.py @@ -5,9 +5,9 @@ from PIL.Image import Image as PILImageType from invokeai.app.models.image import ( ImageCategory, - ImageType, + ResourceOrigin, InvalidImageCategoryException, - InvalidImageTypeException, + InvalidOriginException, ) from invokeai.app.models.metadata import ImageMetadata from invokeai.app.services.image_record_storage import ( @@ -44,12 +44,11 @@ class ImageServiceABC(ABC): def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, intermediate: bool = False, - show_in_gallery: bool = True, ) -> ImageDTO: """Creates an image, storing the file and its metadata.""" pass @@ -57,7 +56,7 @@ class ImageServiceABC(ABC): @abstractmethod def update( self, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: @@ -65,22 +64,22 @@ class ImageServiceABC(ABC): pass @abstractmethod - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: """Gets an image as a PIL image.""" pass @abstractmethod - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: """Gets an image record.""" pass @abstractmethod - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: """Gets an image DTO.""" pass @abstractmethod - def get_path(self, image_type: ImageType, image_name: str) -> str: + def get_path(self, image_origin: ResourceOrigin, image_name: str) -> str: """Gets an image's path.""" pass @@ -91,7 +90,7 @@ class ImageServiceABC(ABC): @abstractmethod def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets an image's or thumbnail's URL.""" pass @@ -101,16 +100,16 @@ class ImageServiceABC(ABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: """Gets a paginated list of image DTOs.""" pass @abstractmethod - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): """Deletes an image.""" pass @@ -171,15 +170,14 @@ class ImageService(ImageServiceABC): def create( self, image: PILImageType, - image_type: ImageType, + image_origin: ResourceOrigin, image_category: ImageCategory, node_id: Optional[str] = None, session_id: Optional[str] = None, is_intermediate: bool = False, - show_in_gallery: bool = True, ) -> ImageDTO: - if image_type not in ImageType: - raise InvalidImageTypeException + if image_origin not in ResourceOrigin: + raise InvalidOriginException if image_category not in ImageCategory: raise InvalidImageCategoryException @@ -195,13 +193,12 @@ class ImageService(ImageServiceABC): created_at = self._services.records.save( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, # Meta fields is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, # Nullable fields node_id=node_id, session_id=session_id, @@ -209,21 +206,21 @@ class ImageService(ImageServiceABC): ) self._services.files.save( - image_type=image_type, + image_origin=image_origin, image_name=image_name, image=image, metadata=metadata, ) - image_url = self._services.urls.get_image_url(image_type, image_name) + image_url = self._services.urls.get_image_url(image_origin, image_name) thumbnail_url = self._services.urls.get_image_url( - image_type, image_name, True + image_origin, image_name, True ) return ImageDTO( # Non-nullable fields image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -236,7 +233,6 @@ class ImageService(ImageServiceABC): updated_at=created_at, # this is always the same as the created_at at this time deleted_at=None, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, # Extra non-nullable fields for DTO image_url=image_url, thumbnail_url=thumbnail_url, @@ -253,13 +249,13 @@ class ImageService(ImageServiceABC): def update( self, - image_type: ImageType, + image_origin: ResourceOrigin, image_name: str, changes: ImageRecordChanges, ) -> ImageDTO: try: - self._services.records.update(image_name, image_type, changes) - return self.get_dto(image_type, image_name) + self._services.records.update(image_name, image_origin, changes) + return self.get_dto(image_origin, image_name) except ImageRecordSaveException: self._services.logger.error("Failed to update image record") raise @@ -267,9 +263,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem updating image record") raise e - def get_pil_image(self, image_type: ImageType, image_name: str) -> PILImageType: + def get_pil_image(self, image_origin: ResourceOrigin, image_name: str) -> PILImageType: try: - return self._services.files.get(image_type, image_name) + return self._services.files.get(image_origin, image_name) except ImageFileNotFoundException: self._services.logger.error("Failed to get image file") raise @@ -277,9 +273,9 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image file") raise e - def get_record(self, image_type: ImageType, image_name: str) -> ImageRecord: + def get_record(self, image_origin: ResourceOrigin, image_name: str) -> ImageRecord: try: - return self._services.records.get(image_type, image_name) + return self._services.records.get(image_origin, image_name) except ImageRecordNotFoundException: self._services.logger.error("Image record not found") raise @@ -287,14 +283,14 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting image record") raise e - def get_dto(self, image_type: ImageType, image_name: str) -> ImageDTO: + def get_dto(self, image_origin: ResourceOrigin, image_name: str) -> ImageDTO: try: - image_record = self._services.records.get(image_type, image_name) + image_record = self._services.records.get(image_origin, image_name) image_dto = image_record_to_dto( image_record, - self._services.urls.get_image_url(image_type, image_name), - self._services.urls.get_image_url(image_type, image_name, True), + self._services.urls.get_image_url(image_origin, image_name), + self._services.urls.get_image_url(image_origin, image_name, True), ) return image_dto @@ -306,10 +302,10 @@ class ImageService(ImageServiceABC): raise e def get_path( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.files.get_path(image_type, image_name, thumbnail) + return self._services.files.get_path(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -322,10 +318,10 @@ class ImageService(ImageServiceABC): raise e def get_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: try: - return self._services.urls.get_image_url(image_type, image_name, thumbnail) + return self._services.urls.get_image_url(image_origin, image_name, thumbnail) except Exception as e: self._services.logger.error("Problem getting image path") raise e @@ -334,28 +330,28 @@ class ImageService(ImageServiceABC): self, page: int = 0, per_page: int = 10, - image_type: Optional[ImageType] = None, - image_category: Optional[ImageCategory] = None, + image_origin: Optional[ResourceOrigin] = None, + include_categories: Optional[list[ImageCategory]] = None, + exclude_categories: Optional[list[ImageCategory]] = None, is_intermediate: Optional[bool] = None, - show_in_gallery: Optional[bool] = None, ) -> PaginatedResults[ImageDTO]: try: results = self._services.records.get_many( page, per_page, - image_type, - image_category, + image_origin, + include_categories, + exclude_categories, is_intermediate, - show_in_gallery, ) image_dtos = list( map( lambda r: image_record_to_dto( r, - self._services.urls.get_image_url(r.image_type, r.image_name), + self._services.urls.get_image_url(r.image_origin, r.image_name), self._services.urls.get_image_url( - r.image_type, r.image_name, True + r.image_origin, r.image_name, True ), ), results.items, @@ -373,10 +369,10 @@ class ImageService(ImageServiceABC): self._services.logger.error("Problem getting paginated image DTOs") raise e - def delete(self, image_type: ImageType, image_name: str): + def delete(self, image_origin: ResourceOrigin, image_name: str): try: - self._services.files.delete(image_type, image_name) - self._services.records.delete(image_type, image_name) + self._services.files.delete(image_origin, image_name) + self._services.records.delete(image_origin, image_name) except ImageRecordDeleteException: self._services.logger.error(f"Failed to delete image record") raise diff --git a/invokeai/app/services/models/image_record.py b/invokeai/app/services/models/image_record.py index faa6e1b41a..f143a30928 100644 --- a/invokeai/app/services/models/image_record.py +++ b/invokeai/app/services/models/image_record.py @@ -1,7 +1,7 @@ import datetime from typing import Optional, Union from pydantic import BaseModel, Extra, Field, StrictStr -from invokeai.app.models.image import ImageCategory, ImageType +from invokeai.app.models.image import ImageCategory, ResourceOrigin from invokeai.app.models.metadata import ImageMetadata from invokeai.app.util.misc import get_iso_timestamp @@ -11,8 +11,8 @@ class ImageRecord(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_category: ImageCategory = Field(description="The category of the image.") """The category of the image.""" width: int = Field(description="The width of the image in px.") @@ -33,8 +33,6 @@ class ImageRecord(BaseModel): """The deleted timestamp of the image.""" is_intermediate: bool = Field(description="Whether this is an intermediate image.") """Whether this is an intermediate image.""" - show_in_gallery: bool = Field(description="Whether this image should be shown in the gallery.") - """Whether this image should be shown in the gallery.""" session_id: Optional[str] = Field( default=None, description="The session ID that generated this image, if it is a generated image.", @@ -76,8 +74,8 @@ class ImageUrlsDTO(BaseModel): image_name: str = Field(description="The unique name of the image.") """The unique name of the image.""" - image_type: ImageType = Field(description="The type of the image.") - """The type of the image.""" + image_origin: ResourceOrigin = Field(description="The type of the image.") + """The origin of the image.""" image_url: str = Field(description="The URL of the image.") """The URL of the image.""" thumbnail_url: str = Field(description="The URL of the image's thumbnail.") @@ -107,7 +105,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: # Retrieve all the values, setting "reasonable" defaults if they are not present. image_name = image_dict.get("image_name", "unknown") - image_type = ImageType(image_dict.get("image_type", ImageType.RESULT.value)) + image_origin = ResourceOrigin(image_dict.get("image_origin", ResourceOrigin.INTERNAL.value)) image_category = ImageCategory( image_dict.get("image_category", ImageCategory.GENERAL.value) ) @@ -119,7 +117,6 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at = image_dict.get("updated_at", get_iso_timestamp()) deleted_at = image_dict.get("deleted_at", get_iso_timestamp()) is_intermediate = image_dict.get("is_intermediate", False) - show_in_gallery = image_dict.get("show_in_gallery", True) raw_metadata = image_dict.get("metadata") @@ -130,7 +127,7 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: return ImageRecord( image_name=image_name, - image_type=image_type, + image_origin=image_origin, image_category=image_category, width=width, height=height, @@ -141,5 +138,4 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: updated_at=updated_at, deleted_at=deleted_at, is_intermediate=is_intermediate, - show_in_gallery=show_in_gallery, ) diff --git a/invokeai/app/services/urls.py b/invokeai/app/services/urls.py index 2716da60ad..4c8354c899 100644 --- a/invokeai/app/services/urls.py +++ b/invokeai/app/services/urls.py @@ -1,7 +1,7 @@ import os from abc import ABC, abstractmethod -from invokeai.app.models.image import ImageType +from invokeai.app.models.image import ResourceOrigin from invokeai.app.util.thumbnails import get_thumbnail_name @@ -10,7 +10,7 @@ class UrlServiceBase(ABC): @abstractmethod def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: """Gets the URL for an image or thumbnail.""" pass @@ -21,14 +21,14 @@ class LocalUrlService(UrlServiceBase): self._base_url = base_url def get_image_url( - self, image_type: ImageType, image_name: str, thumbnail: bool = False + self, image_origin: ResourceOrigin, image_name: str, thumbnail: bool = False ) -> str: image_basename = os.path.basename(image_name) # These paths are determined by the routes in invokeai/app/api/routers/images.py if thumbnail: return ( - f"{self._base_url}/images/{image_type.value}/{image_basename}/thumbnail" + f"{self._base_url}/images/{image_origin.value}/{image_basename}/thumbnail" ) - return f"{self._base_url}/images/{image_type.value}/{image_basename}" + return f"{self._base_url}/images/{image_origin.value}/{image_basename}" diff --git a/invokeai/app/util/step_callback.py b/invokeai/app/util/step_callback.py index 963e770406..b4b9a25909 100644 --- a/invokeai/app/util/step_callback.py +++ b/invokeai/app/util/step_callback.py @@ -1,5 +1,5 @@ -from invokeai.app.api.models.images import ProgressImage from invokeai.app.models.exceptions import CanceledException +from invokeai.app.models.image import ProgressImage from ..invocations.baseinvocation import InvocationContext from ...backend.util.util import image_to_dataURL from ...backend.generator.base import Generator