mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add clipseg support for creating inpaint masks from text
On the command line, the new option is --text_mask or -tm. Example: ``` invoke> a baseball -I /path/to/still_life.png -tm orange ``` This will find the orange fruit in the still life painting and replace it with an image of a baseball.
This commit is contained in:
parent
57bff2a663
commit
5eb0f8ffa7
BIN
docs/assets/still-life-inpainted.png
Normal file
BIN
docs/assets/still-life-inpainted.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 338 KiB |
BIN
docs/assets/still-life-scaled.jpg
Normal file
BIN
docs/assets/still-life-scaled.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 59 KiB |
@ -154,7 +154,7 @@ Here are the invoke> command that apply to txt2img:
|
|||||||
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
|
| --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
|
||||||
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
|
| --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
|
||||||
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
|
| --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
|
||||||
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
||||||
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
|
| --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
|
||||||
| --individual | -i | True | Turn off grid mode (deprecated; leave off --grid instead) |
|
| --individual | -i | True | Turn off grid mode (deprecated; leave off --grid instead) |
|
||||||
| --outdir <path> | -o<path> | outputs/img_samples | Temporarily change the location of these images |
|
| --outdir <path> | -o<path> | outputs/img_samples | Temporarily change the location of these images |
|
||||||
@ -212,11 +212,35 @@ accepts additional options:
|
|||||||
[Inpainting](./INPAINTING.md) for details.
|
[Inpainting](./INPAINTING.md) for details.
|
||||||
|
|
||||||
inpainting accepts all the arguments used for txt2img and img2img, as
|
inpainting accepts all the arguments used for txt2img and img2img, as
|
||||||
well as the --mask (-M) argument:
|
well as the --mask (-M) and --text_mask (-tm) arguments:
|
||||||
|
|
||||||
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|
||||||
|--------------------|------------|---------------------|--------------|
|
|--------------------|------------|---------------------|--------------|
|
||||||
| `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
|
| `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
|
||||||
|
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | <none> | Create a mask from a text prompt describing part of the image|
|
||||||
|
|
||||||
|
`--text_mask` (short form `-tm`) is a way to generate a mask using a
|
||||||
|
text description of the part of the image to replace. For example, if
|
||||||
|
you have an image of a breakfast plate with a bagel, toast and
|
||||||
|
scrambled eggs, you can selectively mask the bagel and replace it with
|
||||||
|
a piece of cake this way:
|
||||||
|
|
||||||
|
~~~
|
||||||
|
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel
|
||||||
|
~~~
|
||||||
|
|
||||||
|
The algorithm uses <a
|
||||||
|
href="https://github.com/timojl/clipseg">clipseg</a> to classify
|
||||||
|
different regions of the image. The classifier puts out a confidence
|
||||||
|
score for each region it identifies. Generally regions that score
|
||||||
|
above 0.5 are reliable, but if you are getting too much or too little
|
||||||
|
masking you can adjust the threshold down (to get more mask), or up
|
||||||
|
(to get less). In this example, by passing `-tm` a higher value, we
|
||||||
|
are insisting on a more stringent classification.
|
||||||
|
|
||||||
|
~~~
|
||||||
|
invoke> a piece of cake -I /path/to/breakfast.png -tm bagel 0.6
|
||||||
|
~~~
|
||||||
|
|
||||||
# Other Commands
|
# Other Commands
|
||||||
|
|
||||||
|
@ -34,7 +34,46 @@ original unedited image and the masked (partially transparent) image:
|
|||||||
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
|
invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
|
||||||
```
|
```
|
||||||
|
|
||||||
We are hoping to get rid of the need for this workaround in an upcoming release.
|
## **Masking using Text**
|
||||||
|
|
||||||
|
You can also create a mask using a text prompt to select the part of
|
||||||
|
the image you want to alter, using the <a
|
||||||
|
href="https://github.com/timojl/clipseg">clipseg</a> algorithm. This
|
||||||
|
works on any image, not just ones generated by InvokeAI.
|
||||||
|
|
||||||
|
The `--text_mask` (short form `-tm`) option takes two arguments. The
|
||||||
|
first argument is a text description of the part of the image you wish
|
||||||
|
to mask (paint over). If the text description contains a space, you must
|
||||||
|
surround it with quotation marks. The optional second argument is the
|
||||||
|
minimum threshold for the mask classifier's confidence score, described
|
||||||
|
in more detail below.
|
||||||
|
|
||||||
|
To see how this works in practice, here's an image of a still life
|
||||||
|
painting that I got off the web.
|
||||||
|
|
||||||
|
<img src="../assets/still-life-scaled.jpg">
|
||||||
|
|
||||||
|
You can selectively mask out the
|
||||||
|
orange and replace it with a baseball in this way:
|
||||||
|
|
||||||
|
~~~
|
||||||
|
invoke> a baseball -I /path/to/still_life.png -tm orange
|
||||||
|
~~~
|
||||||
|
|
||||||
|
<img src="../assets/still-life-inpainted.png">
|
||||||
|
|
||||||
|
The clipseg classifier produces a confidence score for each region it
|
||||||
|
identifies. Generally regions that score above 0.5 are reliable, but
|
||||||
|
if you are getting too much or too little masking you can adjust the
|
||||||
|
threshold down (to get more mask), or up (to get less). In this
|
||||||
|
example, by passing `-tm` a higher value, we are insisting on a tigher
|
||||||
|
mask. However, if you make it too high, the orange may not be picked
|
||||||
|
up at all!
|
||||||
|
|
||||||
|
~~~
|
||||||
|
invoke> a baseball -I /path/to/breakfast.png -tm orange 0.6
|
||||||
|
~~~
|
||||||
|
|
||||||
|
|
||||||
### Inpainting is not changing the masked region enough!
|
### Inpainting is not changing the masked region enough!
|
||||||
|
|
||||||
|
@ -34,7 +34,8 @@ from ldm.invoke.image_util import InitImageResizer
|
|||||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||||
from ldm.invoke.conditioning import get_uc_and_c
|
from ldm.invoke.conditioning import get_uc_and_c
|
||||||
from ldm.invoke.model_cache import ModelCache
|
from ldm.invoke.model_cache import ModelCache
|
||||||
|
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
||||||
def new_func(*args, **kw):
|
def new_func(*args, **kw):
|
||||||
@ -188,6 +189,7 @@ class Generate:
|
|||||||
self.esrgan = esrgan
|
self.esrgan = esrgan
|
||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||||
|
self.txt2mask = None
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -266,6 +268,7 @@ class Generate:
|
|||||||
# these are specific to img2img and inpaint
|
# these are specific to img2img and inpaint
|
||||||
init_img = None,
|
init_img = None,
|
||||||
init_mask = None,
|
init_mask = None,
|
||||||
|
text_mask = None,
|
||||||
fit = False,
|
fit = False,
|
||||||
strength = None,
|
strength = None,
|
||||||
init_color = None,
|
init_color = None,
|
||||||
@ -298,6 +301,8 @@ class Generate:
|
|||||||
seamless // whether the generated image should tile
|
seamless // whether the generated image should tile
|
||||||
hires_fix // whether the Hires Fix should be applied during generation
|
hires_fix // whether the Hires Fix should be applied during generation
|
||||||
init_img // path to an initial image
|
init_img // path to an initial image
|
||||||
|
init_mask // path to a mask for the initial image
|
||||||
|
text_mask // a text string that will be used to guide clipseg generation of the init_mask
|
||||||
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||||
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||||
@ -405,6 +410,7 @@ class Generate:
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
fit=fit,
|
fit=fit,
|
||||||
|
text_mask=text_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
# TODO: Hacky selection of operation to perform. Needs to be refactored.
|
||||||
@ -620,17 +626,14 @@ class Generate:
|
|||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
fit=False,
|
fit=False,
|
||||||
|
text_mask=None,
|
||||||
):
|
):
|
||||||
init_image = None
|
init_image = None
|
||||||
init_mask = None
|
init_mask = None
|
||||||
if not img:
|
if not img:
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
image = self._load_img(
|
image = self._load_img(img)
|
||||||
img,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
)
|
|
||||||
|
|
||||||
if image.width < self.width and image.height < self.height:
|
if image.width < self.width and image.height < self.height:
|
||||||
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
|
print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
|
||||||
@ -648,10 +651,12 @@ class Generate:
|
|||||||
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
|
||||||
|
|
||||||
if mask:
|
if mask:
|
||||||
mask_image = self._load_img(
|
mask_image = self._load_img(mask) # this returns an Image
|
||||||
mask, width, height) # this returns an Image
|
|
||||||
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
|
||||||
|
|
||||||
|
elif text_mask:
|
||||||
|
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
|
||||||
|
|
||||||
return init_image, init_mask
|
return init_image, init_mask
|
||||||
|
|
||||||
def _make_base(self):
|
def _make_base(self):
|
||||||
@ -830,7 +835,7 @@ class Generate:
|
|||||||
|
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
def _load_img(self, img, width, height)->Image:
|
def _load_img(self, img)->Image:
|
||||||
if isinstance(img, Image.Image):
|
if isinstance(img, Image.Image):
|
||||||
image = img
|
image = img
|
||||||
print(
|
print(
|
||||||
@ -892,6 +897,29 @@ class Generate:
|
|||||||
mask = ImageOps.invert(mask)
|
mask = ImageOps.invert(mask)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
# TODO: The latter part of this method repeats code from _create_init_mask()
|
||||||
|
def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
|
||||||
|
prompt = text_mask[0]
|
||||||
|
confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
|
||||||
|
if self.txt2mask is None:
|
||||||
|
self.txt2mask = Txt2Mask(device = self.device)
|
||||||
|
|
||||||
|
segmented = self.txt2mask.segment(image, prompt)
|
||||||
|
mask = segmented.to_mask(float(confidence_level))
|
||||||
|
mask = mask.convert('RGB')
|
||||||
|
# now we adjust the size
|
||||||
|
if fit:
|
||||||
|
mask = self._fit_image(mask, (width, height))
|
||||||
|
else:
|
||||||
|
mask = self._squeeze_image(mask)
|
||||||
|
mask = mask.resize((mask.width//downsampling, mask.height //
|
||||||
|
downsampling), resample=Image.Resampling.NEAREST)
|
||||||
|
mask = np.array(mask)
|
||||||
|
mask = mask.astype(np.float32) / 255.0
|
||||||
|
mask = mask[None].transpose(0, 3, 1, 2)
|
||||||
|
mask = torch.from_numpy(mask)
|
||||||
|
return mask.to(self.device)
|
||||||
|
|
||||||
def _has_transparency(self, image):
|
def _has_transparency(self, image):
|
||||||
if image.info.get("transparency", None) is not None:
|
if image.info.get("transparency", None) is not None:
|
||||||
return True
|
return True
|
||||||
|
@ -678,6 +678,14 @@ class Args(object):
|
|||||||
type=str,
|
type=str,
|
||||||
help='Path to input mask for inpainting mode (supersedes width and height)',
|
help='Path to input mask for inpainting mode (supersedes width and height)',
|
||||||
)
|
)
|
||||||
|
img2img_group.add_argument(
|
||||||
|
'-tm',
|
||||||
|
'--text_mask',
|
||||||
|
nargs='+',
|
||||||
|
type=str,
|
||||||
|
help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
img2img_group.add_argument(
|
img2img_group.add_argument(
|
||||||
'--init_color',
|
'--init_color',
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -74,3 +74,4 @@ class Txt2Img(Generator):
|
|||||||
if self.perlin > 0.0:
|
if self.perlin > 0.0:
|
||||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -54,6 +54,7 @@ COMMANDS = (
|
|||||||
'--hires_fix',
|
'--hires_fix',
|
||||||
'--inpaint_replace','-r',
|
'--inpaint_replace','-r',
|
||||||
'--png_compression','-z',
|
'--png_compression','-z',
|
||||||
|
'--text_mask','-tm',
|
||||||
'!fix','!fetch','!history','!search','!clear',
|
'!fix','!fetch','!history','!search','!clear',
|
||||||
'!models','!switch','!import_model','!edit_model'
|
'!models','!switch','!import_model','!edit_model'
|
||||||
)
|
)
|
||||||
|
@ -36,6 +36,7 @@ from torchvision import transforms
|
|||||||
|
|
||||||
CLIP_VERSION = 'ViT-B/16'
|
CLIP_VERSION = 'ViT-B/16'
|
||||||
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
|
CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
|
||||||
|
CLIPSEG_SIZE = 352
|
||||||
|
|
||||||
class SegmentedGrayscale(object):
|
class SegmentedGrayscale(object):
|
||||||
def __init__(self, image:Image, heatmap:torch.Tensor):
|
def __init__(self, image:Image, heatmap:torch.Tensor):
|
||||||
@ -43,28 +44,39 @@ class SegmentedGrayscale(object):
|
|||||||
self.image = image
|
self.image = image
|
||||||
|
|
||||||
def to_grayscale(self)->Image:
|
def to_grayscale(self)->Image:
|
||||||
return Image.fromarray(np.uint8(self.heatmap*255))
|
return self._rescale(Image.fromarray(np.uint8(self.heatmap*255)))
|
||||||
|
|
||||||
def to_mask(self,threshold:float=0.5)->Image:
|
def to_mask(self,threshold:float=0.5)->Image:
|
||||||
discrete_heatmap = self.heatmap.lt(threshold).int()
|
discrete_heatmap = self.heatmap.lt(threshold).int()
|
||||||
return Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')
|
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
|
||||||
|
|
||||||
def to_transparent(self)->Image:
|
def to_transparent(self)->Image:
|
||||||
transparent_image = self.image.copy()
|
transparent_image = self.image.copy()
|
||||||
transparent_image.putalpha(self.to_image)
|
transparent_image.putalpha(self.to_grayscale())
|
||||||
return transparent_image
|
return transparent_image
|
||||||
|
|
||||||
|
# unscales and uncrops the 352x352 heatmap so that it matches the image again
|
||||||
|
def _rescale(self, heatmap:Image)->Image:
|
||||||
|
size = self.image.width if (self.image.width > self.image.height) else self.image.height
|
||||||
|
resized_image = heatmap.resize(
|
||||||
|
(size,size),
|
||||||
|
resample=Image.Resampling.LANCZOS
|
||||||
|
)
|
||||||
|
return resized_image.crop((0,0,self.image.width,self.image.height))
|
||||||
|
|
||||||
class Txt2Mask(object):
|
class Txt2Mask(object):
|
||||||
'''
|
'''
|
||||||
Create new Txt2Mask object. The optional device argument can be one of
|
Create new Txt2Mask object. The optional device argument can be one of
|
||||||
'cuda', 'mps' or 'cpu'.
|
'cuda', 'mps' or 'cpu'.
|
||||||
'''
|
'''
|
||||||
def __init__(self,device='cpu'):
|
def __init__(self,device='cpu'):
|
||||||
print('>> Initializing clipseg model')
|
print('>> Initializing clipseg model for text to mask inference')
|
||||||
|
self.device = device
|
||||||
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
|
self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
self.model.to(device)
|
# initially we keep everything in cpu to conserve space
|
||||||
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
|
self.model.to('cpu')
|
||||||
|
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
|
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
|
||||||
@ -73,18 +85,38 @@ class Txt2Mask(object):
|
|||||||
provided image and returns a SegmentedGrayscale object in which the brighter
|
provided image and returns a SegmentedGrayscale object in which the brighter
|
||||||
pixels indicate where the object is inferred to be.
|
pixels indicate where the object is inferred to be.
|
||||||
'''
|
'''
|
||||||
|
self._to_device(self.device)
|
||||||
prompts = [prompt] # right now we operate on just a single prompt at a time
|
prompts = [prompt] # right now we operate on just a single prompt at a time
|
||||||
|
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
||||||
transforms.Resize((image.width, image.height)), # must be multiple of 64...
|
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
|
||||||
])
|
])
|
||||||
img = transform(image).unsqueeze(0)
|
|
||||||
|
img = self._scale_and_crop(image)
|
||||||
|
img = transform(img).unsqueeze(0)
|
||||||
|
|
||||||
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
|
preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
|
||||||
heatmap = torch.sigmoid(preds[0][0]).cpu()
|
heatmap = torch.sigmoid(preds[0][0]).cpu()
|
||||||
|
self._to_device('cpu')
|
||||||
return SegmentedGrayscale(image, heatmap)
|
return SegmentedGrayscale(image, heatmap)
|
||||||
|
|
||||||
|
def _to_device(self, device):
|
||||||
|
self.model.to(device)
|
||||||
|
|
||||||
|
def _scale_and_crop(self, image:Image)->Image:
|
||||||
|
scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
|
||||||
|
if image.width > image.height: # width is constraint
|
||||||
|
scale = CLIPSEG_SIZE / image.width
|
||||||
|
else:
|
||||||
|
scale = CLIPSEG_SIZE / image.height
|
||||||
|
scaled_image.paste(
|
||||||
|
image.resize(
|
||||||
|
(int(scale * image.width),
|
||||||
|
int(scale * image.height)
|
||||||
|
),
|
||||||
|
resample=Image.Resampling.LANCZOS
|
||||||
|
),box=(0,0)
|
||||||
|
)
|
||||||
|
return scaled_image
|
||||||
|
Loading…
x
Reference in New Issue
Block a user