mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP
This commit is contained in:
0
invokeai/backend/vto_workflow/__init__.py
Normal file
0
invokeai/backend/vto_workflow/__init__.py
Normal file
62
invokeai/backend/vto_workflow/clipseg.py
Normal file
62
invokeai/backend/vto_workflow/clipseg.py
Normal file
@ -0,0 +1,62 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor
|
||||
|
||||
|
||||
def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
|
||||
# Load the model.
|
||||
clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
|
||||
return clipseg_processor, clipseg_model
|
||||
|
||||
|
||||
def run_clipseg(
|
||||
images: list[Image.Image],
|
||||
prompt: str,
|
||||
clipseg_processor,
|
||||
clipseg_model,
|
||||
clipseg_temp: float,
|
||||
device: torch.device,
|
||||
) -> list[Image.Image]:
|
||||
"""Run ClipSeg on a list of images.
|
||||
|
||||
Args:
|
||||
clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
|
||||
and include more of the background. Recommended range: 0.5 to 1.0.
|
||||
"""
|
||||
|
||||
orig_image_sizes = [img.size for img in images]
|
||||
|
||||
prompts = [prompt] * len(images)
|
||||
# TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
|
||||
inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")
|
||||
|
||||
# Move inputs and clipseg_model to the correct device and dtype.
|
||||
inputs = {k: v.to(device=device) for k, v in inputs.items()}
|
||||
clipseg_model = clipseg_model.to(device=device)
|
||||
|
||||
outputs = clipseg_model(**inputs)
|
||||
|
||||
logits = outputs.logits
|
||||
if logits.ndim == 2:
|
||||
# The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
|
||||
logits = logits.unsqueeze(0)
|
||||
probs = torch.nn.functional.sigmoid(logits / clipseg_temp)
|
||||
# Normalize each mask to 0-255. Note that each mask is normalized independently.
|
||||
probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Make mask greyscale.
|
||||
masks: list[Image.Image] = []
|
||||
for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
|
||||
mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
|
||||
mask = mask.resize(orig_size)
|
||||
masks.append(mask)
|
||||
|
||||
return masks
|
||||
|
||||
|
||||
def select_device() -> torch.device:
|
||||
if torch.cuda.is_available():
|
||||
return torch.device("cuda")
|
||||
|
||||
return torch.device("cpu")
|
57
invokeai/backend/vto_workflow/extract_channel.py
Normal file
57
invokeai/backend/vto_workflow/extract_channel.py
Normal file
@ -0,0 +1,57 @@
|
||||
from enum import Enum
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class ImageChannel(Enum):
|
||||
RGB_R = "RGB_R"
|
||||
RGB_G = "RGB_G"
|
||||
RGB_B = "RGB_B"
|
||||
|
||||
LAB_L = "LAB_L"
|
||||
LAB_A = "LAB_A"
|
||||
LAB_B = "LAB_B"
|
||||
|
||||
HSV_H = "HSV_H"
|
||||
HSV_S = "HSV_S"
|
||||
HSV_V = "HSV_V"
|
||||
|
||||
|
||||
def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
|
||||
"""Extract a channel from an image.
|
||||
|
||||
Args:
|
||||
image (np.ndarray): Shape (H, W, 3) of dtype uint8.
|
||||
channel (ImageChannel): The channel to extract.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Shape (H, W) of dtype uint8.
|
||||
"""
|
||||
if channel == ImageChannel.RGB_R:
|
||||
return image[:, :, 0]
|
||||
elif channel == ImageChannel.RGB_G:
|
||||
return image[:, :, 1]
|
||||
elif channel == ImageChannel.RGB_B:
|
||||
return image[:, :, 2]
|
||||
elif channel == ImageChannel.LAB_L:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 0]
|
||||
elif channel == ImageChannel.LAB_A:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 1]
|
||||
elif channel == ImageChannel.LAB_B:
|
||||
lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
|
||||
return lab[:, :, 2]
|
||||
elif channel == ImageChannel.HSV_H:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 0]
|
||||
elif channel == ImageChannel.HSV_S:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 1]
|
||||
elif channel == ImageChannel.HSV_V:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
|
||||
return hsv[:, :, 2]
|
||||
else:
|
||||
raise ValueError(f"Unknown channel: {channel}")
|
47
invokeai/backend/vto_workflow/overlay_pattern.py
Normal file
47
invokeai/backend/vto_workflow/overlay_pattern.py
Normal file
@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_dress_mask(model_image):
|
||||
"""Return a mask of the dress in the image.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere.
|
||||
"""
|
||||
clipseg_processor, clipseg_model = load_clipseg_model()
|
||||
|
||||
masks = run_clipseg(
|
||||
images=[model_image],
|
||||
prompt="a dress",
|
||||
clipseg_processor=clipseg_processor,
|
||||
clipseg_model=clipseg_model,
|
||||
clipseg_temp=1.0,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
mask_np = np.array(masks[0])
|
||||
thresh = 128
|
||||
binary_mask = mask_np > thresh
|
||||
return binary_mask
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def main():
|
||||
# Load the model image.
|
||||
model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg")
|
||||
|
||||
# Load the pattern image.
|
||||
pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg")
|
||||
|
||||
# Generate a mask for the dress.
|
||||
mask = generate_dress_mask(model_image)
|
||||
|
||||
print("hi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user