mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add mask dilation.
This commit is contained in:
parent
5dd0aeb1ab
commit
d5e824e782
@ -1,3 +1,4 @@
|
|||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@ -32,6 +33,12 @@ class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
ge=0.0,
|
ge=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mask_dilation: int = InputField(
|
||||||
|
description="The number of pixels to dilate the mask by. Default is 1.",
|
||||||
|
default=1,
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
# TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy.
|
# TODO(ryand): Avoid all the unnecessary flip-flopping between PIL and numpy.
|
||||||
original_image = context.images.get_pil(self.original_image.image_name)
|
original_image = context.images.get_pil(self.original_image.image_name)
|
||||||
@ -59,9 +66,13 @@ class VTOInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows))
|
pattern_with_shadows = multiply_images(expanded_pattern, Image.fromarray(shadows))
|
||||||
|
|
||||||
|
# Dilate the mask.
|
||||||
|
clothing_mask_np = np.array(clothing_mask)
|
||||||
|
if self.mask_dilation > 0:
|
||||||
|
clothing_mask_np = cv2.dilate(clothing_mask_np, np.ones((3, 3), np.uint8), iterations=self.mask_dilation)
|
||||||
|
|
||||||
# Merge the pattern with the model image.
|
# Merge the pattern with the model image.
|
||||||
pattern_with_shadows_np = np.array(pattern_with_shadows)
|
pattern_with_shadows_np = np.array(pattern_with_shadows)
|
||||||
clothing_mask_np = np.array(clothing_mask)
|
|
||||||
original_image_np = np.array(original_image)
|
original_image_np = np.array(original_image)
|
||||||
merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np)
|
merged_image = np.where(clothing_mask_np[:, :, None], pattern_with_shadows_np, original_image_np)
|
||||||
merged_image = Image.fromarray(merged_image)
|
merged_image = Image.fromarray(merged_image)
|
||||||
|
Loading…
Reference in New Issue
Block a user