canvas: improve paste back (or try to)

This commit is contained in:
blessedcoolant
2024-02-22 10:23:18 +05:30
committed by psychedelicious
parent 8f6c2a8b92
commit 68d79c002d
7 changed files with 208 additions and 122 deletions

View File

@ -17,16 +17,12 @@ from invokeai.app.invocations.fields import (
WithMetadata,
)
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker
from .baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from .baseinvocation import BaseInvocation, Classification, invocation
@invocation("show_image", title="Show Image", tags=["image"], category="image", version="1.0.1")
@ -934,3 +930,42 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard):
image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto)
@invocation(
"iai_canvas_paste_back",
title="InvokeAI Canvas Paste Back",
tags=["image", "combine"],
category="image",
version="1.0.0",
)
class IAICanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Combines two images by using the mask provided"""
source_image: ImageField = InputField(description="The source image")
target_image: ImageField = InputField(default=None, description="The target image")
mask: ImageField = InputField(
description="The mask to use when pasting",
)
mask_blur: int = InputField(default=0, ge=0, description="The amount to blur the mask by")
def _prepare_mask(self, mask: Image.Image):
mask_array = numpy.array(mask)
kernel = numpy.ones((self.mask_blur, self.mask_blur), numpy.uint8)
dilated_mask_array = cv2.erode(mask_array, kernel, iterations=3)
dilated_mask = Image.fromarray(dilated_mask_array)
if self.mask_blur > 0:
mask = dilated_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
return ImageOps.invert(mask.convert("L"))
def invoke(self, context: InvocationContext) -> ImageOutput:
source_image = context.images.get_pil(self.source_image.image_name)
target_image = context.images.get_pil(self.target_image.image_name)
mask = self._prepare_mask(context.images.get_pil(self.mask.image_name))
# Merge the bands back together
source_image.paste(target_image, (0, 0), mask)
image_dto = context.images.save(image=source_image)
return ImageOutput.build(image_dto)