Add image multiplication.

This commit is contained in:
Ryan Dick
2024-07-25 15:31:46 -04:00
parent 3573d39860
commit bead83a8bc
2 changed files with 61 additions and 14 deletions

View File

@ -29,6 +29,30 @@ def generate_dress_mask(model_image):
return binary_mask
def multiply_images(image_1: Image.Image, image_2: Image.Image) -> Image.Image:
"""Multiply two images together.
Args:
image_1 (Image.Image): The first image.
image_2 (Image.Image): The second image.
Returns:
Image.Image: The product of the two images.
"""
image_1_np = np.array(image_1, dtype=np.float32)
if image_1_np.ndim == 2:
# If the image is greyscale, add a channel dimension.
image_1_np = np.expand_dims(image_1_np, axis=-1)
image_2_np = np.array(image_2, dtype=np.float32)
if image_2_np.ndim == 2:
# If the image is greyscale, add a channel dimension.
image_2_np = np.expand_dims(image_2_np, axis=-1)
product_np = image_1_np * image_2_np // 255
product_np = np.clip(product_np, 0, 255).astype(np.uint8)
product = Image.fromarray(product_np)
return product
@torch.inference_mode()
def main():
# Load the model image.