This commit is contained in:
Ryan Dick 2024-07-25 14:14:12 -04:00
parent ba747373db
commit 36d72baaaa
5 changed files with 276 additions and 0 deletions

110
clothing_workflow.ipynb Normal file
View File

@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "aeb428d0-0817-462c-b5d8-455a0615d305",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from PIL import Image\n",
"import numpy as np\n",
"import cv2\n",
"\n",
"from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask\n",
"from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
"metadata": {},
"outputs": [],
"source": [
" # Load the model image.\n",
"model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
"\n",
"# Load the pattern image.\n",
"pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
"metadata": {},
"outputs": [],
"source": [
"mask = generate_dress_mask(model_image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
"metadata": {},
"outputs": [],
"source": [
"# Visualize mask\n",
"model_image_np = np.array(model_image)\n",
"masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
"mask_image = Image.fromarray(masked_model_image)\n",
"mask_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e51bb545",
"metadata": {},
"outputs": [],
"source": [
"shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec43de4a",
"metadata": {},
"outputs": [],
"source": [
"# Visualize masked shadows\n",
"masked_shadows = (shadows * mask).astype(np.uint8)\n",
"masked_shadows_image = Image.fromarray(masked_shadows)\n",
"masked_shadows_image"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dbb53794",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,62 @@
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
# Load the model.
clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
return clipseg_processor, clipseg_model
def run_clipseg(
images: list[Image.Image],
prompt: str,
clipseg_processor,
clipseg_model,
clipseg_temp: float,
device: torch.device,
) -> list[Image.Image]:
"""Run ClipSeg on a list of images.
Args:
clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
and include more of the background. Recommended range: 0.5 to 1.0.
"""
orig_image_sizes = [img.size for img in images]
prompts = [prompt] * len(images)
# TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")
# Move inputs and clipseg_model to the correct device and dtype.
inputs = {k: v.to(device=device) for k, v in inputs.items()}
clipseg_model = clipseg_model.to(device=device)
outputs = clipseg_model(**inputs)
logits = outputs.logits
if logits.ndim == 2:
# The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
logits = logits.unsqueeze(0)
probs = torch.nn.functional.sigmoid(logits / clipseg_temp)
# Normalize each mask to 0-255. Note that each mask is normalized independently.
probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)
# Make mask greyscale.
masks: list[Image.Image] = []
for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
mask = mask.resize(orig_size)
masks.append(mask)
return masks
def select_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")

View File

@ -0,0 +1,57 @@
from enum import Enum
import cv2
import numpy as np
import numpy.typing as npt
class ImageChannel(Enum):
RGB_R = "RGB_R"
RGB_G = "RGB_G"
RGB_B = "RGB_B"
LAB_L = "LAB_L"
LAB_A = "LAB_A"
LAB_B = "LAB_B"
HSV_H = "HSV_H"
HSV_S = "HSV_S"
HSV_V = "HSV_V"
def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
"""Extract a channel from an image.
Args:
image (np.ndarray): Shape (H, W, 3) of dtype uint8.
channel (ImageChannel): The channel to extract.
Returns:
np.ndarray: Shape (H, W) of dtype uint8.
"""
if channel == ImageChannel.RGB_R:
return image[:, :, 0]
elif channel == ImageChannel.RGB_G:
return image[:, :, 1]
elif channel == ImageChannel.RGB_B:
return image[:, :, 2]
elif channel == ImageChannel.LAB_L:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 0]
elif channel == ImageChannel.LAB_A:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 1]
elif channel == ImageChannel.LAB_B:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 2]
elif channel == ImageChannel.HSV_H:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 0]
elif channel == ImageChannel.HSV_S:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 1]
elif channel == ImageChannel.HSV_V:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 2]
else:
raise ValueError(f"Unknown channel: {channel}")

View File

@ -0,0 +1,47 @@
import numpy as np
import torch
from PIL import Image
from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg
@torch.no_grad()
def generate_dress_mask(model_image):
"""Return a mask of the dress in the image.
Returns:
np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere.
"""
clipseg_processor, clipseg_model = load_clipseg_model()
masks = run_clipseg(
images=[model_image],
prompt="a dress",
clipseg_processor=clipseg_processor,
clipseg_model=clipseg_model,
clipseg_temp=1.0,
device=torch.device("cuda"),
)
mask_np = np.array(masks[0])
thresh = 128
binary_mask = mask_np > thresh
return binary_mask
@torch.inference_mode()
def main():
# Load the model image.
model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg")
# Load the pattern image.
pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg")
# Generate a mask for the dress.
mask = generate_dress_mask(model_image)
print("hi")
if __name__ == "__main__":
main()