Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Commit 36d72baaaa: WIP
Parent: ba747373db
clothing_workflow.ipynb (new file, 110 lines)
@@ -0,0 +1,110 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aeb428d0-0817-462c-b5d8-455a0615d305",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "from invokeai.backend.vto_workflow.overlay_pattern import generate_dress_mask\n",
    "from invokeai.backend.vto_workflow.extract_channel import extract_channel, ImageChannel\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6140d4b7-8238-431c-848e-6f6ae27652f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model image.\n",
    "model_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg\")\n",
    "\n",
    "# Load the pattern image.\n",
    "pattern_image = Image.open(\"/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7186ba-dc0c-4520-ac30-49073a65601a",
   "metadata": {},
   "outputs": [],
   "source": [
    "mask = generate_dress_mask(model_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b935de4-94c5-4be5-bf8e-a5a6e445c811",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize mask\n",
    "model_image_np = np.array(model_image)\n",
    "masked_model_image = (model_image_np * np.expand_dims(mask, -1).astype(np.float32)).astype(np.uint8)\n",
    "mask_image = Image.fromarray(masked_model_image)\n",
    "mask_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e51bb545",
   "metadata": {},
   "outputs": [],
   "source": [
    "shadows = extract_channel(np.array(model_image), ImageChannel.LAB_L)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec43de4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize masked shadows\n",
    "masked_shadows = (shadows * mask).astype(np.uint8)\n",
    "masked_shadows_image = Image.fromarray(masked_shadows)\n",
    "masked_shadows_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbb53794",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
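The final notebook cell is still empty in this WIP commit. For context, here is a minimal sketch of what the remaining compositing step could look like, assuming the pattern is tiled to the model-image size and the LAB lightness channel extracted above is reused as a shading map. Only the variables already defined in the cells above (model_image, pattern_image, mask, shadows) come from the commit; everything else is an assumption.

# Hypothetical sketch, not part of this commit. Uses model_image, pattern_image,
# mask, and shadows from the cells above.
model_np = np.array(model_image)
h, w, _ = model_np.shape

# Tile the pattern so it covers the full model image.
pattern_np = np.array(pattern_image)
reps_y = -(-h // pattern_np.shape[0])  # ceil division
reps_x = -(-w // pattern_np.shape[1])
pattern_tiled = np.tile(pattern_np, (reps_y, reps_x, 1))[:h, :w, :]

# Re-apply the garment shading to the flat pattern using the LAB lightness channel.
shading = (shadows.astype(np.float32) / 255.0)[..., None]
pattern_shaded = np.clip(pattern_tiled.astype(np.float32) * shading, 0, 255).astype(np.uint8)

# Composite: shaded pattern inside the dress mask, original pixels elsewhere.
composite = np.where(mask[..., None], pattern_shaded, model_np)
Image.fromarray(composite)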
invokeai/backend/vto_workflow/__init__.py (new file, 0 lines)

invokeai/backend/vto_workflow/clipseg.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPSegForImageSegmentation, CLIPSegProcessor


def load_clipseg_model() -> tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]:
    # Load the model.
    clipseg_processor = AutoProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
    clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
    return clipseg_processor, clipseg_model


def run_clipseg(
    images: list[Image.Image],
    prompt: str,
    clipseg_processor,
    clipseg_model,
    clipseg_temp: float,
    device: torch.device,
) -> list[Image.Image]:
    """Run ClipSeg on a list of images.

    Args:
        clipseg_temp (float): Temperature applied to the CLIPSeg logits. Higher values cause the mask to be 'smoother'
            and include more of the background. Recommended range: 0.5 to 1.0.
    """
    orig_image_sizes = [img.size for img in images]

    prompts = [prompt] * len(images)
    # TODO(ryand): Should we run the same image with and without the prompt to normalize for any bias in the model?
    inputs = clipseg_processor(text=prompts, images=images, padding=True, return_tensors="pt")

    # Move inputs and clipseg_model to the correct device and dtype.
    inputs = {k: v.to(device=device) for k, v in inputs.items()}
    clipseg_model = clipseg_model.to(device=device)

    outputs = clipseg_model(**inputs)

    logits = outputs.logits
    if logits.ndim == 2:
        # The model squeezes the batch dimension if it's 1, so we need to unsqueeze it.
        logits = logits.unsqueeze(0)
    probs = torch.nn.functional.sigmoid(logits / clipseg_temp)
    # Normalize each mask to 0-255. Note that each mask is normalized independently.
    probs = 255 * probs / probs.amax(dim=(1, 2), keepdim=True)

    # Make mask greyscale.
    masks: list[Image.Image] = []
    for prob, orig_size in zip(probs, orig_image_sizes, strict=True):
        mask = Image.fromarray(prob.cpu().detach().numpy()).convert("L")
        mask = mask.resize(orig_size)
        masks.append(mask)

    return masks


def select_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")

    return torch.device("cpu")
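For reference, a minimal usage sketch of the helpers above; the input path is a placeholder, not part of the commit.

# Hypothetical usage sketch; "example.jpg" is a placeholder path.
import torch
from PIL import Image

from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg, select_device

processor, model = load_clipseg_model()
device = select_device()

image = Image.open("example.jpg").convert("RGB")
with torch.no_grad():
    masks = run_clipseg(
        images=[image],
        prompt="a dress",
        clipseg_processor=processor,
        clipseg_model=model,
        clipseg_temp=1.0,
        device=device,
    )
masks[0].save("mask.png")  # Greyscale mask resized back to the input image size.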
invokeai/backend/vto_workflow/extract_channel.py (new file, 57 lines)
@@ -0,0 +1,57 @@
from enum import Enum

import cv2
import numpy as np
import numpy.typing as npt


class ImageChannel(Enum):
    RGB_R = "RGB_R"
    RGB_G = "RGB_G"
    RGB_B = "RGB_B"

    LAB_L = "LAB_L"
    LAB_A = "LAB_A"
    LAB_B = "LAB_B"

    HSV_H = "HSV_H"
    HSV_S = "HSV_S"
    HSV_V = "HSV_V"


def extract_channel(image: npt.NDArray[np.uint8], channel: ImageChannel) -> npt.NDArray[np.uint8]:
    """Extract a channel from an image.

    Args:
        image (np.ndarray): Shape (H, W, 3) of dtype uint8.
        channel (ImageChannel): The channel to extract.

    Returns:
        np.ndarray: Shape (H, W) of dtype uint8.
    """
    if channel == ImageChannel.RGB_R:
        return image[:, :, 0]
    elif channel == ImageChannel.RGB_G:
        return image[:, :, 1]
    elif channel == ImageChannel.RGB_B:
        return image[:, :, 2]
    elif channel == ImageChannel.LAB_L:
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        return lab[:, :, 0]
    elif channel == ImageChannel.LAB_A:
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        return lab[:, :, 1]
    elif channel == ImageChannel.LAB_B:
        lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
        return lab[:, :, 2]
    elif channel == ImageChannel.HSV_H:
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        return hsv[:, :, 0]
    elif channel == ImageChannel.HSV_S:
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        return hsv[:, :, 1]
    elif channel == ImageChannel.HSV_V:
        hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
        return hsv[:, :, 2]
    else:
        raise ValueError(f"Unknown channel: {channel}")
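A short usage sketch of extract_channel as it is used in the notebook above, extracting the LAB lightness channel as a shading map; the image path is a placeholder.

# Hypothetical usage sketch; "model.jpeg" is a placeholder path.
import numpy as np
from PIL import Image

from invokeai.backend.vto_workflow.extract_channel import ImageChannel, extract_channel

image_np = np.array(Image.open("model.jpeg").convert("RGB"))  # (H, W, 3) uint8
lightness = extract_channel(image_np, ImageChannel.LAB_L)  # (H, W) uint8 shading map
Image.fromarray(lightness).save("lightness.png")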
invokeai/backend/vto_workflow/overlay_pattern.py (new file, 47 lines)
@@ -0,0 +1,47 @@
import numpy as np
import torch
from PIL import Image

from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg


@torch.no_grad()
def generate_dress_mask(model_image):
    """Return a mask of the dress in the image.

    Returns:
        np.ndarray: Shape (H, W) of dtype bool. True where the dress is, False elsewhere.
    """
    clipseg_processor, clipseg_model = load_clipseg_model()

    masks = run_clipseg(
        images=[model_image],
        prompt="a dress",
        clipseg_processor=clipseg_processor,
        clipseg_model=clipseg_model,
        clipseg_temp=1.0,
        device=torch.device("cuda"),
    )

    mask_np = np.array(masks[0])
    thresh = 128
    binary_mask = mask_np > thresh
    return binary_mask


@torch.inference_mode()
def main():
    # Load the model image.
    model_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/dress.jpeg")

    # Load the pattern image.
    pattern_image = Image.open("/home/ryan/src/InvokeAI/invokeai/backend/vto_workflow/pattern1.jpg")

    # Generate a mask for the dress.
    mask = generate_dress_mask(model_image)

    print("hi")


if __name__ == "__main__":
    main()
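generate_dress_mask hard-codes torch.device("cuda"), even though clipseg.py already provides select_device(). A possible variant of that call (not part of this commit) that falls back to CPU when CUDA is unavailable:

# Hypothetical variant of the run_clipseg call inside generate_dress_mask; reuses
# the existing select_device() helper so the mask can also be generated on CPU.
from invokeai.backend.vto_workflow.clipseg import load_clipseg_model, run_clipseg, select_device

# model_image: a PIL.Image loaded as in main() above.
clipseg_processor, clipseg_model = load_clipseg_model()
masks = run_clipseg(
    images=[model_image],
    prompt="a dress",
    clipseg_processor=clipseg_processor,
    clipseg_model=clipseg_model,
    clipseg_temp=1.0,
    device=select_device(),
)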