This commit is contained in:
Ryan Dick
2024-07-25 14:14:12 -04:00
parent ba747373db
commit 36d72baaaa
5 changed files with 276 additions and 0 deletions

View File

@@ -0,0 +1,62 @@
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
# Load the model.
clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
return clipseg_processor, clipseg_model
def run_clipseg(
    images: list[Image.Image],
    prompt: str,
    clipseg_processor: CLIPSegProcessor,
    clipseg_model: CLIPSegForImageSegmentation,
    clipseg_temp: float,
    device: torch.device,
) -> list[Image.Image]:
"""Run ClipSeg on a list of images.
Args:
clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
and include more of the background. Recommended range: 0.5 to 1.0.
"""
orig_image_sizes = [img.size for img in images]
prompts = [prompt] * len(images)
# TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")
    # Move the inputs and the model to the target device (no dtype conversion is done here).
inputs = {k: v.to(device=device) for k, v in inputs.items()}
clipseg_model = clipseg_model.to(device=device)
outputs = clipseg_model(**inputs)
logits = outputs.logits
if logits.ndim == 2:
# The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
logits = logits.unsqueeze(0)
    probs = torch.sigmoid(logits / clipseg_temp)
# Normalize each mask to 0-255. Note that each mask is normalized independently.
probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)
    # Convert each mask to a greyscale PIL image and resize it to the original image size.
masks: list[Image.Image] = []
for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
mask = mask.resize(orig_size)
masks.append(mask)
return masks
def select_device() -> torch.device:
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")

View File

@@ -0,0 +1,57 @@
from enum import Enum
import cv2
import numpy as np
import numpy.typing as npt
class ImageChannel(Enum):
RGB_R = "RGB_R"
RGB_G = "RGB_G"
RGB_B = "RGB_B"
LAB_L = "LAB_L"
LAB_A = "LAB_A"
LAB_B = "LAB_B"
HSV_H = "HSV_H"
HSV_S = "HSV_S"
HSV_V = "HSV_V"
def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
"""Extract a channel from an image.
Args:
image (np.ndarray): Shape (H, W, 3) of dtype uint8.
channel (ImageChannel): The channel to extract.
Returns:
np.ndarray: Shape (H, W) of dtype uint8.
"""
if channel == ImageChannel.RGB_R:
return image[:, :, 0]
elif channel == ImageChannel.RGB_G:
return image[:, :, 1]
elif channel == ImageChannel.RGB_B:
return image[:, :, 2]
elif channel == ImageChannel.LAB_L:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 0]
elif channel == ImageChannel.LAB_A:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 1]
elif channel == ImageChannel.LAB_B:
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
return lab[:, :, 2]
elif channel == ImageChannel.HSV_H:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 0]
elif channel == ImageChannel.HSV_S:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 1]
elif channel == ImageChannel.HSV_V:
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
return hsv[:, :, 2]
else:
raise ValueError(f"Unknown channel: {channel}")

View File

@@ -0,0 +1,47 @@
import numpy as np
import torch
from PIL import Image
from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg, select_device
@torch.no_grad()
def generate_dress_mask(model_image: Image.Image) -> np.ndarray:
    """Return a mask of the dress in the image.

    Args:
        model_image (Image.Image): The image containing the dress to be masked.

    Returns:
        np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere.
    """
clipseg_processor, clipseg_model = load_clipseg_model()
masks = run_clipseg(
images=[model_image],
prompt="a dress",
clipseg_processor=clipseg_processor,
clipseg_model=clipseg_model,
clipseg_temp=1.0,
        device=select_device(),
)
    # Threshold the greyscale mask (values in 0-255) to get a boolean mask.
    mask_np = np.array(masks[0])
    thresh = 128
    binary_mask = mask_np > thresh
return binary_mask
@torch.inference_mode()
def main():
# Load the model image.
model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg")
# Load the pattern image.
pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg")
# Generate a mask for the dress.
mask = generate_dress_mask(model_image)
print("hi")
if __name__ == "__main__":
main()
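
As of this commit, main() stops after computing the mask and the loaded pattern_image is not yet used. Purely as an illustration of how the boolean mask from generate_dress_mask() and the pattern image might be combined, here is a minimal sketch; apply_pattern is a hypothetical helper, not part of this diff.

import numpy as np
from PIL import Image

def apply_pattern(model_image: Image.Image, pattern_image: Image.Image, mask: np.ndarray) -> Image.Image:
    """Hypothetical helper: paste the pattern over the region where the boolean mask is True."""
    # Resize the pattern to cover the model image (tiling would be an alternative).
    pattern_np = np.array(pattern_image.convert("RGB").resize(model_image.size))
    model_np = np.array(model_image.convert("RGB"))
    # Broadcast the (H, W) mask over the RGB channels.
    out = np.where(mask[:, :, None], pattern_np, model_np)
    return Image.fromarray(out.astype(np.uint8))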