From d2c223de8f19f56ac1f4ea0b6c66eba2cef06152 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 24 May 2023 15:50:55 +1000 Subject: [PATCH] feat(nodes): move fully* to new images service * except i haven't rebuilt inpaint in latents --- invokeai/app/api/dependencies.py | 22 +- invokeai/app/api/routers/images.py | 20 +- invokeai/app/invocations/cv.py | 42 ++-- invokeai/app/invocations/image.py | 226 ++++++++++--------- invokeai/app/invocations/infill.py | 105 +++++---- invokeai/app/invocations/latent.py | 32 ++- invokeai/app/invocations/reconstruct.py | 39 ++-- invokeai/app/invocations/upscale.py | 37 +-- invokeai/app/services/invocation_services.py | 8 +- 9 files changed, 273 insertions(+), 258 deletions(-) diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py index ae351d4476..99e0f7238f 100644 --- a/invokeai/app/api/dependencies.py +++ b/invokeai/app/api/dependencies.py @@ -55,16 +55,6 @@ class ApiDependencies: os.path.join(os.path.dirname(__file__), "../../../../outputs") ) - latents = ForwardCacheLatentsStorage( - DiskLatentsStorage(f"{output_folder}/latents") - ) - - metadata = CoreMetadataService() - - urls = LocalUrlService() - - image_file_storage = DiskImageFileStorage(f"{output_folder}/images") - # TODO: build a file/path manager? db_location = os.path.join(output_folder, "invokeai.db") @@ -72,9 +62,16 @@ class ApiDependencies: filename=db_location, table_name="graph_executions" ) + urls = LocalUrlService() + metadata = CoreMetadataService() image_record_storage = SqliteImageRecordStorage(db_location) + image_file_storage = DiskImageFileStorage(f"{output_folder}/images") - images_new = ImageService( + latents = ForwardCacheLatentsStorage( + DiskLatentsStorage(f"{output_folder}/latents") + ) + + images = ImageService( image_record_storage=image_record_storage, image_file_storage=image_file_storage, metadata=metadata, @@ -87,8 +84,7 @@ class ApiDependencies: model_manager=get_model_manager(config, logger), events=events, latents=latents, - images=image_file_storage, - images_new=images_new, + images=images, queue=MemoryInvocationQueue(), graph_library=SqliteItemStorage[LibraryGraph]( filename=db_location, table_name="graphs" diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index 602b539da1..0615ff187e 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -45,7 +45,7 @@ async def upload_image( raise HTTPException(status_code=415, detail="Failed to read image") try: - image_dto = ApiDependencies.invoker.services.images_new.create( + image_dto = ApiDependencies.invoker.services.images.create( pil_image, image_type, image_category, @@ -67,7 +67,7 @@ async def delete_image( """Deletes an image""" try: - ApiDependencies.invoker.services.images_new.delete(image_type, image_name) + ApiDependencies.invoker.services.images.delete(image_type, image_name) except Exception as e: # TODO: Does this need any exception handling at all? pass @@ -85,7 +85,7 @@ async def get_image_metadata( """Gets an image's metadata""" try: - return ApiDependencies.invoker.services.images_new.get_dto( + return ApiDependencies.invoker.services.images.get_dto( image_type, image_name ) except Exception as e: @@ -113,11 +113,11 @@ async def get_image_full( """Gets a full-resolution image file""" try: - path = ApiDependencies.invoker.services.images_new.get_path( + path = ApiDependencies.invoker.services.images.get_path( image_type, image_name ) - if not ApiDependencies.invoker.services.images_new.validate_path(path): + if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) return FileResponse( @@ -149,10 +149,10 @@ async def get_image_thumbnail( """Gets a thumbnail image file""" try: - path = ApiDependencies.invoker.services.images_new.get_path( + path = ApiDependencies.invoker.services.images.get_path( image_type, image_name, thumbnail=True ) - if not ApiDependencies.invoker.services.images_new.validate_path(path): + if not ApiDependencies.invoker.services.images.validate_path(path): raise HTTPException(status_code=404) return FileResponse( @@ -174,10 +174,10 @@ async def get_image_urls( """Gets an image and thumbnail URL""" try: - image_url = ApiDependencies.invoker.services.images_new.get_url( + image_url = ApiDependencies.invoker.services.images.get_url( image_type, image_name ) - thumbnail_url = ApiDependencies.invoker.services.images_new.get_url( + thumbnail_url = ApiDependencies.invoker.services.images.get_url( image_type, image_name, thumbnail=True ) return ImageUrlsDTO( @@ -205,7 +205,7 @@ async def list_images_with_metadata( ) -> PaginatedResults[ImageDTO]: """Gets a list of images with metadata""" - image_dtos = ApiDependencies.invoker.services.images_new.get_many( + image_dtos = ApiDependencies.invoker.services.images.get_many( image_type, image_category, page, diff --git a/invokeai/app/invocations/cv.py b/invokeai/app/invocations/cv.py index 5a6d703d83..26e06a2af8 100644 --- a/invokeai/app/invocations/cv.py +++ b/invokeai/app/invocations/cv.py @@ -7,9 +7,9 @@ import numpy from PIL import Image, ImageOps from pydantic import BaseModel, Field -from invokeai.app.models.image import ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ImageType from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput, build_image_output +from .image import ImageOutput class CvInvocationConfig(BaseModel): @@ -26,24 +26,27 @@ class CvInvocationConfig(BaseModel): class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): """Simple inpaint using opencv.""" - #fmt: off + + # fmt: off type: Literal["cv_inpaint"] = "cv_inpaint" # Inputs image: ImageField = Field(default=None, description="The image to inpaint") mask: ImageField = Field(default=None, description="The mask to use when inpainting") - #fmt: on + # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) - mask = context.services.images.get(self.mask.image_type, self.mask.image_name) + mask = context.services.images.get_pil_image( + self.mask.image_type, self.mask.image_name + ) # Convert to cv image/mask # TODO: consider making these utility functions cv_image = cv.cvtColor(numpy.array(image.convert("RGB")), cv.COLOR_RGB2BGR) - cv_mask = numpy.array(ImageOps.invert(mask)) + cv_mask = numpy.array(ImageOps.invert(mask.convert("L"))) # Inpaint cv_inpainted = cv.inpaint(cv_image, cv_mask, 3, cv.INPAINT_TELEA) @@ -52,18 +55,19 @@ class CvInpaintInvocation(BaseInvocation, CvInvocationConfig): # TODO: consider making a utility function image_inpainted = Image.fromarray(cv.cvtColor(cv_inpainted, cv.COLOR_BGR2RGB)) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=image_inpainted, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) - - context.services.images.save(image_type, image_name, image_inpainted, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=image_inpainted, - ) \ No newline at end of file diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 56141cbb0e..8f789853ac 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1,13 +1,13 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) import io -from typing import Literal, Optional +from typing import Literal, Optional, Union import numpy from PIL import Image, ImageFilter, ImageOps from pydantic import BaseModel, Field -from ..models.image import ImageField, ImageType +from ..models.image import ImageCategory, ImageField, ImageType from .baseinvocation import ( BaseInvocation, BaseInvocationOutput, @@ -41,27 +41,14 @@ class ImageOutput(BaseInvocationOutput): schema_extra = {"required": ["type", "image", "width", "height"]} -def build_image_output( - image_type: ImageType, image_name: str, image: Image.Image -) -> ImageOutput: - """Builds an ImageOutput and its ImageField""" - image_field = ImageField( - image_name=image_name, - image_type=image_type, - ) - return ImageOutput( - image=image_field, - width=image.width, - height=image.height, - ) - - class MaskOutput(BaseInvocationOutput): """Base class for invocations that output a mask""" # fmt: off type: Literal["mask"] = "mask" mask: ImageField = Field(default=None, description="The output mask") + width: int = Field(description="The width of the mask in pixels") + height: int = Field(description="The height of the mask in pixels") # fmt: on class Config: @@ -84,12 +71,15 @@ class LoadImageInvocation(BaseInvocation): image_name: str = Field(description="The name of the image") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get(self.image_type, self.image_name) + image = context.services.images.get_pil_image(self.image_type, self.image_name) - return build_image_output( - image_type=self.image_type, - image_name=self.image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=self.image_name, + image_type=self.image_type, + ), + width=image.width, + height=image.height, ) @@ -99,10 +89,12 @@ class ShowImageInvocation(BaseInvocation): type: Literal["show_image"] = "show_image" # Inputs - image: ImageField = Field(default=None, description="The image to show") + image: Union[ImageField, None] = Field( + default=None, description="The image to show" + ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) if image: @@ -110,10 +102,13 @@ class ShowImageInvocation(BaseInvocation): # TODO: how to handle failure? - return build_image_output( - image_type=self.image.image_type, - image_name=self.image.image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=self.image.image_name, + image_type=self.image.image_type, + ), + width=image.width, + height=image.height, ) @@ -124,7 +119,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig): type: Literal["crop"] = "crop" # Inputs - image: ImageField = Field(default=None, description="The image to crop") + image: Union[ImageField, None] = Field(default=None, description="The image to crop") x: int = Field(default=0, description="The left x coordinate of the crop rectangle") y: int = Field(default=0, description="The top y coordinate of the crop rectangle") width: int = Field(default=512, gt=0, description="The width of the crop rectangle") @@ -132,7 +127,7 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig): # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -141,20 +136,21 @@ class CropImageInvocation(BaseInvocation, PILInvocationConfig): ) image_crop.paste(image, (-self.x, -self.y)) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id - ) - - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, image_crop, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, + image_dto = context.services.images.create( image=image_crop, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -165,25 +161,27 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig): type: Literal["paste"] = "paste" # Inputs - base_image: ImageField = Field(default=None, description="The base image") - image: ImageField = Field(default=None, description="The image to paste") + base_image: Union[ImageField, None] = Field(default=None, description="The base image") + image: Union[ImageField, None] = Field(default=None, description="The image to paste") mask: Optional[ImageField] = Field(default=None, description="The mask to use when pasting") x: int = Field(default=0, description="The left x coordinate at which to paste the image") y: int = Field(default=0, description="The top y coordinate at which to paste the image") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - base_image = context.services.images.get( + base_image = context.services.images.get_pil_image( self.base_image.image_type, self.base_image.image_name ) - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) mask = ( None if self.mask is None else ImageOps.invert( - context.services.images.get(self.mask.image_type, self.mask.image_name) + context.services.images.get_pil_image( + self.mask.image_type, self.mask.image_name + ) ) ) # TODO: probably shouldn't invert mask here... should user be required to do it? @@ -199,20 +197,21 @@ class PasteImageInvocation(BaseInvocation, PILInvocationConfig): new_image.paste(base_image, (abs(min_x), abs(min_y))) new_image.paste(image, (max(0, self.x), max(0, self.y)), mask=mask) - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id - ) - - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, new_image, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, + image_dto = context.services.images.create( image=new_image, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -223,12 +222,12 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): type: Literal["tomask"] = "tomask" # Inputs - image: ImageField = Field(default=None, description="The image to create the mask from") + image: Union[ImageField, None] = Field(default=None, description="The image to create the mask from") invert: bool = Field(default=False, description="Whether or not to invert the mask") # fmt: on def invoke(self, context: InvocationContext) -> MaskOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -236,18 +235,22 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): if self.invert: image_mask = ImageOps.invert(image_mask) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=image_mask, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self + return MaskOutput( + mask=ImageField( + image_type=image_dto.image_type, image_name=image_dto.image_name + ), + width=image_dto.width, + height=image_dto.height, ) - context.services.images.save(image_type, image_name, image_mask, metadata) - return MaskOutput(mask=ImageField(image_type=image_type, image_name=image_name)) - class BlurInvocation(BaseInvocation, PILInvocationConfig): """Blurs an image""" @@ -256,13 +259,13 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig): type: Literal["blur"] = "blur" # Inputs - image: ImageField = Field(default=None, description="The image to blur") + image: Union[ImageField, None] = Field(default=None, description="The image to blur") radius: float = Field(default=8.0, ge=0, description="The blur radius") blur_type: Literal["gaussian", "box"] = Field(default="gaussian", description="The type of blur") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -273,18 +276,21 @@ class BlurInvocation(BaseInvocation, PILInvocationConfig): ) blur_image = image.filter(blur) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=blur_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, blur_image, metadata) - return build_image_output( - image_type=image_type, image_name=image_name, image=blur_image + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -295,13 +301,13 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig): type: Literal["lerp"] = "lerp" # Inputs - image: ImageField = Field(default=None, description="The image to lerp") + image: Union[ImageField, None] = Field(default=None, description="The image to lerp") min: int = Field(default=0, ge=0, le=255, description="The minimum output value") max: int = Field(default=255, ge=0, le=255, description="The maximum output value") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -310,18 +316,21 @@ class LerpInvocation(BaseInvocation, PILInvocationConfig): lerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=lerp_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, lerp_image, metadata) - return build_image_output( - image_type=image_type, image_name=image_name, image=lerp_image + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -332,13 +341,13 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig): type: Literal["ilerp"] = "ilerp" # Inputs - image: ImageField = Field(default=None, description="The image to lerp") + image: Union[ImageField, None] = Field(default=None, description="The image to lerp") min: int = Field(default=0, ge=0, le=255, description="The minimum input value") max: int = Field(default=255, ge=0, le=255, description="The maximum input value") # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -352,16 +361,19 @@ class InverseLerpInvocation(BaseInvocation, PILInvocationConfig): ilerp_image = Image.fromarray(numpy.uint8(image_arr)) - image_type = ImageType.INTERMEDIATE - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=ilerp_image, + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, ilerp_image, metadata) - return build_image_output( - image_type=image_type, image_name=image_name, image=ilerp_image + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) diff --git a/invokeai/app/invocations/infill.py b/invokeai/app/invocations/infill.py index ac055cef5b..17a43dbdac 100644 --- a/invokeai/app/invocations/infill.py +++ b/invokeai/app/invocations/infill.py @@ -1,17 +1,17 @@ -# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) +# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team -from typing import Literal, Optional, Union, get_args +from typing import Literal, Union, get_args import numpy as np import math from PIL import Image, ImageOps from pydantic import Field -from invokeai.app.invocations.image import ImageOutput, build_image_output +from invokeai.app.invocations.image import ImageOutput from invokeai.app.util.misc import SEED_MAX, get_random_seed from invokeai.backend.image_util.patchmatch import PatchMatch -from ..models.image import ColorField, ImageField, ImageType +from ..models.image import ColorField, ImageCategory, ImageField, ImageType from .baseinvocation import ( BaseInvocation, InvocationContext, @@ -125,36 +125,39 @@ class InfillColorInvocation(BaseInvocation): """Infills transparent areas of an image with a solid color""" type: Literal["infill_rgba"] = "infill_rgba" - image: Optional[ImageField] = Field(default=None, description="The image to infill") - color: Optional[ColorField] = Field( + image: Union[ImageField, None] = Field( + default=None, description="The image to infill" + ) + color: ColorField = Field( default=ColorField(r=127, g=127, b=127, a=255), description="The color to use to infill", ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) solid_bg = Image.new("RGBA", image.size, self.color.tuple()) - infilled = Image.alpha_composite(solid_bg, image) + infilled = Image.alpha_composite(solid_bg, image.convert("RGBA")) infilled.paste(image, (0, 0), image.split()[-1]) - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=infilled, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, infilled, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -163,7 +166,9 @@ class InfillTileInvocation(BaseInvocation): type: Literal["infill_tile"] = "infill_tile" - image: Optional[ImageField] = Field(default=None, description="The image to infill") + image: Union[ImageField, None] = Field( + default=None, description="The image to infill" + ) tile_size: int = Field(default=32, ge=1, description="The tile size (px)") seed: int = Field( ge=0, @@ -173,7 +178,7 @@ class InfillTileInvocation(BaseInvocation): ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -182,20 +187,21 @@ class InfillTileInvocation(BaseInvocation): ) infilled.paste(image, (0, 0), image.split()[-1]) - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=infilled, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, infilled, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -204,10 +210,12 @@ class InfillPatchMatchInvocation(BaseInvocation): type: Literal["infill_patchmatch"] = "infill_patchmatch" - image: Optional[ImageField] = Field(default=None, description="The image to infill") + image: Union[ImageField, None] = Field( + default=None, description="The image to infill" + ) def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -216,18 +224,19 @@ class InfillPatchMatchInvocation(BaseInvocation): else: raise ValueError("PatchMatch is not available on this system") - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=infilled, + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self - ) - - context.services.images.save(image_type, image_name, infilled, metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 40ba67861a..1fcd434852 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -3,7 +3,7 @@ import random from typing import Literal, Optional, Union import einops -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator import torch from invokeai.app.invocations.util.choose_model import choose_model @@ -23,7 +23,7 @@ from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationCont import numpy as np from ..services.image_file_storage import ImageType from .baseinvocation import BaseInvocation, InvocationContext -from .image import ImageField, ImageOutput, build_image_output +from .image import ImageField, ImageOutput from .compel import ConditioningField from ...backend.stable_diffusion import PipelineIntermediateState from diffusers.schedulers import SchedulerMixin as Scheduler @@ -362,19 +362,9 @@ class LatentsToImageInvocation(BaseInvocation): np_image = model.decode_latents(latents) image = model.numpy_to_pil(np_image)[0] - # image_type = ImageType.RESULT - # image_name = context.services.images.create_name( - # context.graph_execution_state_id, self.id - # ) + torch.cuda.empty_cache() - # metadata = context.services.metadata.build_metadata( - # session_id=context.graph_execution_state_id, node=self - # ) - - # torch.cuda.empty_cache() - - # context.services.images.save(image_type, image_name, image, metadata) - image_dto = context.services.images_new.create( + image_dto = context.services.images.create( image=image, image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, @@ -382,10 +372,13 @@ class LatentsToImageInvocation(BaseInvocation): node_id=self.id, ) - return build_image_output( - image_type=image_dto.image_type, - image_name=image_dto.image_name, - image=image, + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) @@ -474,7 +467,7 @@ class ImageToLatentsInvocation(BaseInvocation): @torch.no_grad() def invoke(self, context: InvocationContext) -> LatentsOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) @@ -496,3 +489,4 @@ class ImageToLatentsInvocation(BaseInvocation): name = f"{context.graph_execution_state_id}__{self.id}" context.services.latents.save(name, latents) return build_latents_output(latents_name=name, latents=latents) + diff --git a/invokeai/app/invocations/reconstruct.py b/invokeai/app/invocations/reconstruct.py index 94a7277acd..024134cd46 100644 --- a/invokeai/app/invocations/reconstruct.py +++ b/invokeai/app/invocations/reconstruct.py @@ -2,21 +2,23 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ImageType from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput, build_image_output +from .image import ImageOutput + class RestoreFaceInvocation(BaseInvocation): """Restores faces in an image.""" - #fmt: off + + # fmt: off type: Literal["restore_face"] = "restore_face" # Inputs image: Union[ImageField, None] = Field(description="The input image") strength: float = Field(default=0.75, gt=0, le=1, description="The strength of the restoration" ) - #fmt: on - + # fmt: on + # Schema customisation class Config(InvocationConfig): schema_extra = { @@ -26,7 +28,7 @@ class RestoreFaceInvocation(BaseInvocation): } def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( @@ -39,18 +41,19 @@ class RestoreFaceInvocation(BaseInvocation): # Results are image and seed, unwrap for now # TODO: can this return multiple results? - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=results[0][0], + image_type=ImageType.INTERMEDIATE, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) - - context.services.images.save(image_type, image_name, results[0][0], metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=results[0][0] - ) \ No newline at end of file diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index c4938dfd19..75aeec784f 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -4,22 +4,22 @@ from typing import Literal, Union from pydantic import Field -from invokeai.app.models.image import ImageField, ImageType +from invokeai.app.models.image import ImageCategory, ImageField, ImageType from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig -from .image import ImageOutput, build_image_output +from .image import ImageOutput class UpscaleInvocation(BaseInvocation): """Upscales an image.""" - #fmt: off + + # fmt: off type: Literal["upscale"] = "upscale" # Inputs image: Union[ImageField, None] = Field(description="The input image", default=None) strength: float = Field(default=0.75, gt=0, le=1, description="The strength") level: Literal[2, 4] = Field(default=2, description="The upscale level") - #fmt: on - + # fmt: on # Schema customisation class Config(InvocationConfig): @@ -30,7 +30,7 @@ class UpscaleInvocation(BaseInvocation): } def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get( + image = context.services.images.get_pil_image( self.image.image_type, self.image.image_name ) results = context.services.restoration.upscale_and_reconstruct( @@ -43,18 +43,19 @@ class UpscaleInvocation(BaseInvocation): # Results are image and seed, unwrap for now # TODO: can this return multiple results? - image_type = ImageType.RESULT - image_name = context.services.images.create_name( - context.graph_execution_state_id, self.id + image_dto = context.services.images.create( + image=results[0][0], + image_type=ImageType.RESULT, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, ) - metadata = context.services.metadata.build_metadata( - session_id=context.graph_execution_state_id, node=self + return ImageOutput( + image=ImageField( + image_name=image_dto.image_name, + image_type=image_dto.image_type, + ), + width=image_dto.width, + height=image_dto.height, ) - - context.services.images.save(image_type, image_name, results[0][0], metadata) - return build_image_output( - image_type=image_type, - image_name=image_name, - image=results[0][0] - ) \ No newline at end of file diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index a85089554c..16b603e89f 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -6,7 +6,6 @@ from invokeai.app.services.images import ImageService from invokeai.backend import ModelManager from .events import EventServiceBase from .latent_storage import LatentsStorageBase -from .image_file_storage import ImageFileStorageBase from .restoration_services import RestorationServices from .invocation_queue import InvocationQueueABC from .item_storage import ItemStorageABC @@ -23,12 +22,11 @@ class InvocationServices: events: EventServiceBase latents: LatentsStorageBase - images: ImageFileStorageBase queue: InvocationQueueABC model_manager: ModelManager restoration: RestorationServices configuration: InvokeAISettings - images_new: ImageService + images: ImageService # NOTE: we must forward-declare any types that include invocations, since invocations can use services graph_library: ItemStorageABC["LibraryGraph"] @@ -41,9 +39,8 @@ class InvocationServices: events: EventServiceBase, logger: Logger, latents: LatentsStorageBase, - images: ImageFileStorageBase, + images: ImageService, queue: InvocationQueueABC, - images_new: ImageService, graph_library: ItemStorageABC["LibraryGraph"], graph_execution_manager: ItemStorageABC["GraphExecutionState"], processor: "InvocationProcessorABC", @@ -56,7 +53,6 @@ class InvocationServices: self.latents = latents self.images = images self.queue = queue - self.images_new = images_new self.graph_library = graph_library self.graph_execution_manager = graph_execution_manager self.processor = processor