Add mask dilation.

This commit is contained in:
Ryan Dick 2024-07-25 18:17:18 -04:00
parent 5dd0aeb1ab
commit d5e824e782

View File

@ -1,3 +1,4 @@
import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
@ -32,6 +33,12 @@ class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
ge=0.0, ge=0.0,
) )
mask_dilation: int = InputField(
description="The number of pixels to dilate the mask by. Default is 1.",
default=1,
ge=0,
)
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy. # TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy.
original_image = context.images.get_pil(self.original_image.image_name) original_image = context.images.get_pil(self.original_image.image_name)
@ -59,9 +66,13 @@ class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows)) pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows))
# Dilate the mask.
clothing_mask_np = np.array(clothing_mask)
if self.mask_dilation > 0:
clothing_mask_np = cv2.dilate(clothing_mask_np, np.ones((3, 3), np.uint8), iterations=self.mask_dilation)
# Merge the pattern with the model image. # Merge the pattern with the model image.
pattern_with_shadows_np = np.array(pattern_with_shadows) pattern_with_shadows_np = np.array(pattern_with_shadows)
clothing_mask_np = np.array(clothing_mask)
original_image_np = np.array(original_image) original_image_np = np.array(original_image)
merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np) merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np)
merged_image = Image.fromarray(merged_image) merged_image = Image.fromarray(merged_image)