diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index b25d3735c2..21dfb4c1cd 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -67,16 +67,17 @@ class LoadImageInvocation(BaseInvocation): type: Literal["load_image"] = "load_image" # Inputs - image_type: ImageType = Field(description="The type of the image") - image_name: str = Field(description="The name of the image") + image: Union[ImageField, None] = Field( + default=None, description="The image to load" + ) # fmt: on def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.services.images.get_pil_image(self.image_type, self.image_name) + image = context.services.images.get_pil_image(self.image.image_type, self.image.image_name) return ImageOutput( image=ImageField( - image_name=self.image_name, - image_type=self.image_type, + image_name=self.image.image_name, + image_type=self.image.image_type, ), width=image.width, height=image.height, @@ -138,7 +139,7 @@ class ImageCropInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_crop, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -237,7 +238,7 @@ class MaskFromAlphaInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=image_mask, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.MASK, node_id=self.id, session_id=context.graph_execution_state_id, @@ -275,7 +276,7 @@ class ImageMultiplyInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=multiply_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -313,7 +314,7 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=channel_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -328,7 +329,8 @@ class ImageChannelInvocation(BaseInvocation, PILInvocationConfig): ) -IMAGE_MODES = Literal['L', 'RGB', 'RGBA', 'CMYK', 'YCbCr', 'LAB', 'HSV', 'I', 'F'] +IMAGE_MODES = Literal["L", "RGB", "RGBA", "CMYK", "YCbCr", "LAB", "HSV", "I", "F"] + class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): """Converts an image to a different mode.""" @@ -350,7 +352,7 @@ class ImageConvertInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=converted_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -391,7 +393,7 @@ class ImageBlurInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=blur_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -431,7 +433,7 @@ class ImageLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=lerp_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, @@ -476,7 +478,7 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig): image_dto = context.services.images.create( image=ilerp_image, - image_type=ImageType.INTERMEDIATE, + image_type=ImageType.RESULT, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id,