From 36d72baaaa6bfe95f29de71246ad3d111453508e Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 25 Jul 2024 14:14:12 -0400
Subject: [PATCH] WIP

---
 clothing_workflow.ipynb                      | 110 ++++++++++++++++++
 invokeai/backend/vto_workflow/__init__.py    |   0
 invokeai/backend/vto_workflow/clipseg.py     |  62 ++++++++++
 .../backend/vto_workflow/extract_channel.py  |  57 +++++++++
 .../backend/vto_workflow/overlay_pattern.py  |  47 ++++++++
 5 files changed, 276 insertions(+)
 create mode 100644 clothing_workflow.ipynb
 create mode 100644 invokeai/backend/vto_workflow/__init__.py
 create mode 100644 invokeai/backend/vto_workflow/clipseg.py
 create mode 100644 invokeai/backend/vto_workflow/extract_channel.py
 create mode 100644 invokeai/backend/vto_workflow/overlay_pattern.py

diff --git a/clothing_workflow.ipynb b/clothing_workflow.ipynb
new file mode 100644
index 0000000000..8dcff681ec
--- /dev/null
+++ b/clothing_workflow.ipynb
@@ -0,0 +1,110 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aeb428d0-0817-462c-b5d8-455a0615d305",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from PIL import Image\n",
+    "import numpy as np\n",
+    "import cv2\n",
+    "\n",
+    "from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask\n",
+    "from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    " # Load the model image.\n",
+    "model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
+    "\n",
+    "# Load the pattern image.\n",
+    "pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mask = generate_dress_mask(model_image)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Visualize mask\n",
+    "model_image_np = np.array(model_image)\n",
+    "masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
+    "mask_image = Image.fromarray(masked_model_image)\n",
+    "mask_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e51bb545",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "ec43de4a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Visualize masked shadows\n",
+    "masked_shadows = (shadows * mask).astype(np.uint8)\n",
+    "masked_shadows_image = Image.fromarray(masked_shadows)\n",
+    "masked_shadows_image"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dbb53794",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/invokeai/backend/vto_workflow/__init__.py b/invokeai/backend/vto_workflow/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/backend/vto_workflow/clipseg.py b/invokeai/backend/vto_workflow/clipseg.py
new file mode 100644
index 0000000000..f427608dbb
--- /dev/null
+++ b/invokeai/backend/vto_workflow/clipseg.py
@@ -0,0 +1,62 @@
+import torch
+from PIL import Image
+from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
+
+
+def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
+    # Load the model.
+    clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
+    clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+    return clipseg_processor, clipseg_model
+
+
+def run_clipseg(
+    images: list[Image.Image],
+    prompt: str,
+    clipseg_processor,
+    clipseg_model,
+    clipseg_temp: float,
+    device: torch.device,
+) -> list[Image.Image]:
+    """Run ClipSeg on a list of images.
+
+    Args:
+        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
+            and include more of the background. Recommended range: 0.5 to 1.0.
+    """
+
+    orig_image_sizes = [img.size for img in images]
+
+    prompts = [prompt] * len(images)
+    # TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
+    inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")
+
+    # Move inputs and clipseg_model to the correct device and dtype.
+    inputs = {k: v.to(device=device) for k, v in inputs.items()}
+    clipseg_model = clipseg_model.to(device=device)
+
+    outputs = clipseg_model(**inputs)
+
+    logits = outputs.logits
+    if logits.ndim == 2:
+        # The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
+        logits = logits.unsqueeze(0)
+    probs = torch.nn.functional.sigmoid(logits / clipseg_temp)
+    # Normalize each mask to 0-255. Note that each mask is normalized independently.
+    probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)
+
+    # Make mask greyscale.
+    masks: list[Image.Image] = []
+    for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
+        mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
+        mask = mask.resize(orig_size)
+        masks.append(mask)
+
+    return masks
+
+
+def select_device() -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+
+    return torch.device("cpu")
diff --git a/invokeai/backend/vto_workflow/extract_channel.py b/invokeai/backend/vto_workflow/extract_channel.py
new file mode 100644
index 0000000000..0917b198a0
--- /dev/null
+++ b/invokeai/backend/vto_workflow/extract_channel.py
@@ -0,0 +1,57 @@
+from enum import Enum
+
+import cv2
+import numpy as np
+import numpy.typing as npt
+
+
+class ImageChannel(Enum):
+    RGB_R = "RGB_R"
+    RGB_G = "RGB_G"
+    RGB_B = "RGB_B"
+
+    LAB_L = "LAB_L"
+    LAB_A = "LAB_A"
+    LAB_B = "LAB_B"
+
+    HSV_H = "HSV_H"
+    HSV_S = "HSV_S"
+    HSV_V = "HSV_V"
+
+
+def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
+    """Extract a channel from an image.
+
+    Args:
+        image (np.ndarray): Shape (H, W, 3) of dtype uint8.
+        channel (ImageChannel): The channel to extract.
+
+    Returns:
+        np.ndarray: Shape (H, W) of dtype uint8.
+ """ + if channel == ImageChannel.RGB_R: + return image[:, :, 0] + elif channel == ImageChannel.RGB_G: + return image[:, :, 1] + elif channel == ImageChannel.RGB_B: + return image[:, :, 2] + elif channel == ImageChannel.LAB_L: + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + return lab[:, :, 0] + elif channel == ImageChannel.LAB_A: + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + return lab[:, :, 1] + elif channel == ImageChannel.LAB_B: + lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB) + return lab[:, :, 2] + elif channel == ImageChannel.HSV_H: + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + return hsv[:, :, 0] + elif channel == ImageChannel.HSV_S: + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + return hsv[:, :, 1] + elif channel == ImageChannel.HSV_V: + hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) + return hsv[:, :, 2] + else: + raise ValueError(f"Unknown channel: {channel}") diff --git a/invokeai/backend/vto_workflow/overlay_pattern.py b/invokeai/backend/vto_workflow/overlay_pattern.py new file mode 100644 index 0000000000..d40af66bee --- /dev/null +++ b/invokeai/backend/vto_workflow/overlay_pattern.py @@ -0,0 +1,47 @@ +import numpy as np +import torch +from PIL import Image + +from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg + + +@torch.no_grad() +def generate_dress_mask(model_image): + """Return a mask of the dress in the image. + + Returns: + np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere. + """ + clipseg_processor, clipseg_model = load_clipseg_model() + + masks = run_clipseg( + images=[model_image], + prompt="a dress", + clipseg_processor=clipseg_processor, + clipseg_model=clipseg_model, + clipseg_temp=1.0, + device=torch.device("cuda"), + ) + + mask_np = np.array(masks[0]) + thresh = 128 + binary_mask = mask_np > thresh + return binary_mask + + +@torch.inference_mode() +def main(): + # Load the model image. + model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg") + + # Load the pattern image. + pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg") + + # Generate a mask for the dress. + mask = generate_dress_mask(model_image) + + print("hi") + + +if __name__ == "__main__": + main()