diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index 3ddd3a3051..2c70a141fb 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -1032,7 +1032,11 @@ class CanvasV2MaskAndCropOutput(ImageOutput): class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): """Handles Canvas V2 image output masking and cropping""" - image: ImageField = InputField(description="The image to apply the mask to") + source_image: ImageField | None = InputField( + default=None, + description="The source image onto which the masked generated image is pasted. If omitted, the masked generated image is returned with transparency.", + ) + generated_image: ImageField = InputField(description="The image to apply the mask to") mask: ImageField = InputField(description="The mask to apply") mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by") @@ -1046,33 +1050,25 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): return ImageOps.invert(mask.convert("L")) def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput: - image = context.images.get_pil(self.image.image_name) mask = self._prepare_mask(context.images.get_pil(self.mask.image_name)) - image.putalpha(mask) + + if self.source_image: + generated_image = context.images.get_pil(self.generated_image.image_name) + source_image = context.images.get_pil(self.source_image.image_name) + source_image.paste(generated_image, (0, 0), mask) + image_dto = context.images.save(image=source_image) + else: + generated_image = context.images.get_pil(self.generated_image.image_name) + generated_image.putalpha(mask) + image_dto = context.images.save(image=generated_image) + # bbox = image.getbbox() # image = image.crop(bbox) - image_dto = context.images.save(image=image) return CanvasV2MaskAndCropOutput( image=ImageField(image_name=image_dto.image_name), offset_x=0, offset_y=0, - width=image.width, - height=image.height, + width=image_dto.width, + height=image_dto.height, ) - - # def invoke(self, context: InvocationContext) -> CanvasV2MaskAndCropOutput: - # image = context.images.get_pil(self.image.image_name) - # mask = self._prepare_mask(context.images.get_pil(self.mask.image_name)) - # image.putalpha(mask) - # bbox = image.getbbox() - # image = image.crop(bbox) - # image_dto = context.images.save(image=image) - - # return CanvasV2MaskAndCropOutput( - # image=ImageField(image_name=image_dto.image_name), - # offset_x=bbox[0], - # offset_y=bbox[1], - # width=image.width, - # height=image.height, - # )