InvokeAI/invokeai/backend/image_util/txt2mask.py

131 lines
4.6 KiB
Python
Raw Normal View History

2023-03-03 06:02:00 +00:00
"""Makes available the Txt2Mask class, which assists in the automatic
assignment of masks via text prompt using clipseg.

Here is typical usage:

    from invokeai.backend.image_util.txt2mask import Txt2Mask, SegmentedGrayscale
    from PIL import Image

    txt2mask = Txt2Mask(device='cuda')
    segmented = txt2mask.segment(Image.open('/path/to/img.png'), 'a bagel')

    # this will return a grayscale Image of the segmented data, in which
    # brighter pixels indicate higher confidence
    grayscale = segmented.to_grayscale()

    # this will return a semi-transparent copy of the image in which the
    # selected object(s) are transparent and the rest is at various
    # levels of opacity
    transparent = segmented.to_transparent()

    # this will return a mask image suitable for use in inpainting
    mask = segmented.to_mask(threshold=0.5)

In the mask returned by to_mask(), pixels whose confidence is at or above
the threshold are set to black and all other pixels to white (the edges are
softened when the mask is resized to the image's dimensions). Threshold
values range from 0.0 to 1.0; the higher the threshold, the more confident
the model must be before a pixel is selected. In limited testing, I have
found that values around 0.5 work fine.
"""
2023-03-03 06:02:00 +00:00
import numpy as np
import torch
from PIL import Image, ImageOps
2023-03-03 06:02:00 +00:00
from transformers import AutoProcessor, CLIPSegForImageSegmentation
2023-04-29 13:43:40 +00:00
import invokeai.backend.util.logging as logger
2023-03-03 05:02:15 +00:00
from invokeai.backend.globals import global_cache_dir
2023-03-03 06:02:00 +00:00
# Hugging Face Hub id of the CLIPSeg checkpoint that Txt2Mask loads.
CLIPSEG_MODEL: str = "CIDAS/clipseg-rd64-refined"
# Side length, in pixels, of the square canvas that images are fitted into
# before segmentation (and hence of the heatmap that comes back).
CLIPSEG_SIZE: int = 352
2023-03-03 06:02:00 +00:00
class SegmentedGrayscale(object):
    """A CLIPSeg confidence heatmap paired with the image it was computed for.

    The heatmap holds per-pixel confidences (Txt2Mask produces it with a
    sigmoid, so values lie in [0, 1]). It is laid out the way
    Txt2Mask._scale_and_crop lays out the image: the image is scaled so its
    longer side spans the square heatmap and is anchored at the top-left
    corner. Every output method maps the heatmap back onto the original
    image's size via _rescale().

    NOTE(review): the heatmap is presumably a 2-D (352, 352) tensor for a
    single prompt -- confirm against the CLIPSeg model output.
    """

    def __init__(self, image: Image.Image, heatmap: torch.Tensor):
        self.heatmap = heatmap
        self.image = image

    def _heatmap_array(self) -> np.ndarray:
        """Return the heatmap as a NumPy array on the CPU.

        Detaching and moving to the CPU first means a heatmap that lives on a
        GPU (or still tracks gradients) converts cleanly instead of raising.
        For a plain CPU tensor this is a no-copy view of the same values.
        """
        return self.heatmap.detach().cpu().numpy()

    def to_grayscale(self, invert: bool = False) -> Image.Image:
        """Return the heatmap as an 8-bit grayscale image the size of self.image.

        :param invert: when False, brighter pixels mean higher confidence;
            when True, the scale is flipped.
        """
        values = self._heatmap_array() * 255
        if invert:
            values = 255 - values
        # np.uint8 truncates the fractional part, matching the original conversion.
        return self._rescale(Image.fromarray(np.uint8(values)))

    def to_mask(self, threshold: float = 0.5) -> Image.Image:
        """Return a thresholded mask the size of self.image.

        Pixels whose confidence is at or above threshold become black (0) and
        all others white (255), so the matched object(s) are black. The LANCZOS
        resampling in _rescale softens the edges, so border pixels may take
        intermediate gray values.
        """
        below = self._heatmap_array() < threshold
        return self._rescale(Image.fromarray(np.uint8(below * 255), mode="L"))

    def to_transparent(self, invert: bool = False) -> Image.Image:
        """Return a copy of self.image whose alpha channel comes from the heatmap.

        By default the selected regions become transparent and the rest stays
        opaque; pass invert=True for the opposite.
        """
        transparent_image = self.image.copy()
        # For img2img, we want the selected regions to be transparent,
        # but to_grayscale() returns the opposite. Thus invert.
        gs = self.to_grayscale(not invert)
        transparent_image.putalpha(gs)
        return transparent_image

    def _rescale(self, heatmap: Image.Image) -> Image.Image:
        """Unscale and uncrop the square heatmap so it aligns with self.image."""
        # The image was scaled so its longer side spans the square and pasted
        # at the top-left; stretch the square back to that longer side, then
        # crop away the padding on the right/bottom.
        size = max(self.image.width, self.image.height)
        resized_image = heatmap.resize((size, size), resample=Image.Resampling.LANCZOS)
        return resized_image.crop((0, 0, self.image.width, self.image.height))
class Txt2Mask(object):
    """
    Create new Txt2Mask object. The optional device argument can be one of
    'cuda', 'mps' or 'cpu'.
    """

    def __init__(self, device="cpu", refined=False):
        """Load the CLIPSeg processor and model, placing the model on device.

        :param device: torch device (e.g. 'cuda', 'mps', 'cpu' or a
            torch.device) on which inference runs.
        :param refined: currently unused; CLIPSEG_MODEL is always loaded.
        """
        logger.info("Initializing clipseg model for text to mask inference")
        self.device = device
        self.processor = AutoProcessor.from_pretrained(
            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            CLIPSEG_MODEL, cache_dir=global_cache_dir("hub")
        ).to(self.device)

    @torch.no_grad()
    def segment(self, image, prompt: str) -> SegmentedGrayscale:
        """
        Given a prompt string such as "a bagel", tries to identify the object in the
        provided image and returns a SegmentedGrayscale object in which the brighter
        pixels indicate where the object is inferred to be.

        :param image: a PIL image, or a path (str or os.PathLike) to an image
            file, which is loaded and converted to RGB.
        :param prompt: text describing the object to segment.
        """
        import os

        if isinstance(image, (str, os.PathLike)):
            # The context manager closes the file handle once the pixels are copied.
            with Image.open(image) as opened:
                image = opened.convert("RGB")
        image = ImageOps.exif_transpose(image)
        img = self._scale_and_crop(image)

        inputs = self.processor(
            text=[prompt], images=[img], padding=True, return_tensors="pt"
        ).to(self.device)
        outputs = self.model(**inputs)
        # Return the heatmap on the CPU so it can be handed to NumPy/PIL.
        heatmap = torch.sigmoid(outputs.logits).cpu()
        return SegmentedGrayscale(image, heatmap)

    def _scale_and_crop(self, image: Image.Image) -> Image.Image:
        """Fit image into a CLIPSEG_SIZE x CLIPSEG_SIZE black RGB canvas.

        The image is scaled by its longer side so it fits entirely, and it is
        pasted at the top-left corner; SegmentedGrayscale._rescale relies on
        this layout to map the heatmap back onto the image.
        """
        scaled_image = Image.new("RGB", (CLIPSEG_SIZE, CLIPSEG_SIZE))
        scale = CLIPSEG_SIZE / max(image.width, image.height)
        # Clamp to at least 1 px so extremely thin images don't request a
        # zero-sized resize.
        new_size = (
            max(1, int(scale * image.width)),
            max(1, int(scale * image.height)),
        )
        scaled_image.paste(
            image.resize(new_size, resample=Image.Resampling.LANCZOS),
            box=(0, 0),
        )
        return scaled_image