diff --git a/docs/assets/still-life-inpainted.png b/docs/assets/still-life-inpainted.png
new file mode 100644
index 0000000000..ab8c7bd69a
Binary files /dev/null and b/docs/assets/still-life-inpainted.png differ
diff --git a/docs/assets/still-life-scaled.jpg b/docs/assets/still-life-scaled.jpg
new file mode 100644
index 0000000000..ba9c86be00
Binary files /dev/null and b/docs/assets/still-life-scaled.jpg differ
diff --git a/docs/features/CLI.md b/docs/features/CLI.md
index 1f861d45cb..8c3d04eb50 100644
--- a/docs/features/CLI.md
+++ b/docs/features/CLI.md
@@ -154,7 +154,7 @@ Here are the invoke> command that apply to txt2img:
 | --seed <int> | -S<int> | None | Set the random seed for the next series of images. This can be used to recreate an image generated previously.|
 | --sampler <sampler>| -A<sampler>| k_lms | Sampler to use. Use -h to get list of available samplers. |
 | --hires_fix | | | Larger images often have duplication artefacts. This option suppresses duplicates by generating the image at low res, and then using img2img to increase the resolution |
-| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
+| --png_compression <0-9> | -z<0-9> | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
 | --grid | -g | False | Turn on grid mode to return a single image combining all the images generated by this prompt |
 | --individual | -i | True | Turn off grid mode (deprecated; leave off --grid instead) |
 | --outdir <path> | -o<path> | outputs/img_samples | Temporarily change the location of these images |
@@ -212,11 +212,35 @@ accepts additional options:
 [Inpainting](./INPAINTING.md) for details.
 
 inpainting accepts all the arguments used for txt2img and img2img, as
-well as the --mask (-M) argument:
+well as the --mask (-M) and --text_mask (-tm) arguments:
 
 | Argument | Shortcut | Default | Description |
 |--------------------|------------|---------------------|--------------|
 | `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
+| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | | Create a mask from a text prompt describing part of the image|
+
+`--text_mask` (short form `-tm`) is a way to generate a mask using a
+text description of the part of the image to replace. For example, if
+you have an image of a breakfast plate with a bagel, toast and
+scrambled eggs, you can selectively mask the bagel and replace it with
+a piece of cake this way:
+
+~~~
+invoke> a piece of cake -I /path/to/breakfast.png -tm bagel
+~~~
+
+The algorithm uses clipseg to classify
+different regions of the image. The classifier puts out a confidence
+score for each region it identifies. Generally regions that score
+above 0.5 are reliable, but if you are getting too much or too little
+masking you can adjust the threshold down (to get more mask), or up
+(to get less). In this example, by passing `-tm` a higher value, we
+are insisting on a more stringent classification.
+
+~~~
+invoke> a piece of cake -I /path/to/breakfast.png -tm bagel 0.6
+~~~
 
 # Other Commands
 
diff --git a/docs/features/INPAINTING.md b/docs/features/INPAINTING.md
index c488c72d16..ac558917e7 100644
--- a/docs/features/INPAINTING.md
+++ b/docs/features/INPAINTING.md
@@ -34,7 +34,46 @@ original unedited image and the masked (partially transparent) image:
 invoke> "man with cat on shoulder" -I./images/man.png -M./images/man-transparent.png
 ```
 
-We are hoping to get rid of the need for this workaround in an upcoming release.
+## **Masking using Text**
+
+You can also create a mask using a text prompt to select the part of
+the image you want to alter, using the clipseg algorithm. This
+works on any image, not just ones generated by InvokeAI.
+
+The `--text_mask` (short form `-tm`) option takes two arguments. The
+first argument is a text description of the part of the image you wish
+to mask (paint over). If the text description contains a space, you must
+surround it with quotation marks. The optional second argument is the
+minimum threshold for the mask classifier's confidence score, described
+in more detail below.
+
+To see how this works in practice, here's an image of a still life
+painting that I got off the web.
+
+<img src="../assets/still-life-scaled.jpg">
+
+You can selectively mask out the
+orange and replace it with a baseball in this way:
+
+~~~
+invoke> a baseball -I /path/to/still_life.png -tm orange
+~~~
+
+<img src="../assets/still-life-inpainted.png">
+
+The clipseg classifier produces a confidence score for each region it
+identifies. Generally regions that score above 0.5 are reliable, but
+if you are getting too much or too little masking you can adjust the
+threshold down (to get more mask), or up (to get less). In this
+example, by passing `-tm` a higher value, we are insisting on a tighter
+mask. However, if you make it too high, the orange may not be picked
+up at all!
+
+~~~
+invoke> a baseball -I /path/to/still_life.png -tm orange 0.6
+~~~
+
 
 ### Inpainting is not changing the masked region enough!
 
diff --git a/ldm/generate.py b/ldm/generate.py
index fb7b0dc26f..4db60af876 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -34,7 +34,8 @@ from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.conditioning import get_uc_and_c
 from ldm.invoke.model_cache import ModelCache
-
+from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
+
 def fix_func(orig):
     if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
         def new_func(*args, **kw):
@@ -188,6 +189,7 @@ class Generate:
         self.esrgan = esrgan
         self.free_gpu_mem = free_gpu_mem
         self.size_matters = True  # used to warn once about large image sizes and VRAM
+        self.txt2mask = None
 
         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -266,6 +268,7 @@ class Generate:
             # these are specific to img2img and inpaint
             init_img = None,
             init_mask = None,
+            text_mask = None,
             fit = False,
             strength = None,
             init_color = None,
@@ -298,6 +301,8 @@ class Generate:
            seamless // whether the generated image should tile
           hires_fix // whether the Hires Fix should be applied during generation
            init_img // path to an initial image
+           init_mask // path to a mask for the initial image
+           text_mask // a text string that will be used to guide clipseg generation of the init_mask
            strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
            facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
           ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@@ -405,6 +410,7 @@ class Generate:
                 width,
                 height,
                 fit=fit,
+                text_mask=text_mask,
             )
 
             # TODO: Hacky selection of operation to perform. Needs to be refactored.
@@ -620,17 +626,14 @@ class Generate:
             width,
             height,
             fit=False,
+            text_mask=None,
     ):
         init_image = None
         init_mask = None
         if not img:
             return None, None
 
-        image = self._load_img(
-            img,
-            width,
-            height,
-        )
+        image = self._load_img(img)
 
         if image.width < self.width and image.height < self.height:
             print(f'>> WARNING: img2img and inpainting may produce unexpected results with initial images smaller than {self.width}x{self.height} in both dimensions')
@@ -648,10 +651,12 @@ class Generate:
         init_image = self._create_init_image(image,width,height,fit=fit) # this returns a torch tensor
 
         if mask:
-            mask_image = self._load_img(
-                mask, width, height) # this returns an Image
+            mask_image = self._load_img(mask) # this returns an Image
             init_mask = self._create_init_mask(mask_image,width,height,fit=fit)
 
+        elif text_mask:
+            init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
+
         return init_image, init_mask
 
     def _make_base(self):
@@ -830,7 +835,7 @@ class Generate:
 
             print(msg)
 
-    def _load_img(self, img, width, height)->Image:
+    def _load_img(self, img)->Image:
         if isinstance(img, Image.Image):
             image = img
             print(
@@ -892,6 +897,29 @@ class Generate:
             mask = ImageOps.invert(mask)
         return mask
 
+    # TODO: The latter part of this method repeats code from _create_init_mask()
+    def _txt2mask(self, image:Image, text_mask:list, width, height, fit=True) -> Image:
+        prompt = text_mask[0]
+        confidence_level = text_mask[1] if len(text_mask)>1 else 0.5
+        if self.txt2mask is None:
+            self.txt2mask = Txt2Mask(device = self.device)
+
+        segmented = self.txt2mask.segment(image, prompt)
+        mask = segmented.to_mask(float(confidence_level))
+        mask = mask.convert('RGB')
+        # now we adjust the size
+        if fit:
+            mask = self._fit_image(mask, (width, height))
+        else:
+            mask = self._squeeze_image(mask)
+        mask = mask.resize((mask.width//downsampling, mask.height //
+                            downsampling), resample=Image.Resampling.NEAREST)
+        mask = np.array(mask)
+        mask = mask.astype(np.float32) / 255.0
+        mask = mask[None].transpose(0, 3, 1, 2)
+        mask = torch.from_numpy(mask)
+        return mask.to(self.device)
+
     def _has_transparency(self, image):
         if image.info.get("transparency", None) is not None:
             return True
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 293d93d5a8..8cb6e6ec59 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -677,6 +677,14 @@ class Args(object):
             type=str,
             help='Path to input mask for inpainting mode (supersedes width and height)',
         )
+        img2img_group.add_argument(
+            '-tm',
+            '--text_mask',
+            nargs='+',
+            type=str,
+            help='Use the clipseg classifier to generate the mask area for inpainting. Provide a description of the area to mask ("a mug"), optionally followed by the confidence level threshold (0-1.0; defaults to 0.5).',
+            default=None,
+        )
         img2img_group.add_argument(
             '--init_color',
             type=str,
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 3502de9be9..edd12c948c 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -74,3 +74,4 @@ class Txt2Img(Generator):
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
+
diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py
index 309a7c55a1..0a79894029 100644
--- a/ldm/invoke/readline.py
+++ b/ldm/invoke/readline.py
@@ -54,6 +54,7 @@ COMMANDS = (
     '--hires_fix',
     '--inpaint_replace','-r',
     '--png_compression','-z',
+    '--text_mask','-tm',
     '!fix','!fetch','!history','!search','!clear',
     '!models','!switch','!import_model','!edit_model'
     )
diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py
index 3e2b84d591..01d93546e3 100644
--- a/ldm/invoke/txt2mask.py
+++ b/ldm/invoke/txt2mask.py
@@ -36,6 +36,7 @@ from torchvision import transforms
 
 CLIP_VERSION = 'ViT-B/16'
 CLIPSEG_WEIGHTS = 'src/clipseg/weights/rd64-uni.pth'
+CLIPSEG_SIZE = 352
 
 class SegmentedGrayscale(object):
     def __init__(self, image:Image, heatmap:torch.Tensor):
@@ -43,28 +44,39 @@ class SegmentedGrayscale(object):
         self.image = image
 
     def to_grayscale(self)->Image:
-        return Image.fromarray(np.uint8(self.heatmap*255))
+        return self._rescale(Image.fromarray(np.uint8(self.heatmap*255)))
 
     def to_mask(self,threshold:float=0.5)->Image:
         discrete_heatmap = self.heatmap.lt(threshold).int()
-        return Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')
+        return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
 
     def to_transparent(self)->Image:
         transparent_image = self.image.copy()
-        transparent_image.putalpha(self.to_image)
+        transparent_image.putalpha(self.to_grayscale())
         return transparent_image
 
+    # unscales and uncrops the 352x352 heatmap so that it matches the image again
+    def _rescale(self, heatmap:Image)->Image:
+        size = self.image.width if (self.image.width > self.image.height) else self.image.height
+        resized_image = heatmap.resize(
+            (size,size),
+            resample=Image.Resampling.LANCZOS
+        )
+        return resized_image.crop((0,0,self.image.width,self.image.height))
+
 class Txt2Mask(object):
     '''
     Create new Txt2Mask object. The optional device argument can be one of
    'cuda', 'mps' or 'cpu'.
     '''
     def __init__(self,device='cpu'):
-        print('>> Initializing clipseg model')
+        print('>> Initializing clipseg model for text to mask inference')
+        self.device = device
         self.model = CLIPDensePredT(version=CLIP_VERSION, reduce_dim=64, )
         self.model.eval()
-        self.model.to(device)
-        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device(device)), strict=False)
+        # initially we keep everything in cpu to conserve space
+        self.model.to('cpu')
+        self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
 
     @torch.no_grad()
     def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
@@ -73,18 +85,38 @@ class Txt2Mask(object):
         provided image and returns a SegmentedGrayscale object in which the brighter
         pixels indicate where the object is inferred to be.
         '''
+        self._to_device(self.device)
         prompts = [prompt] # right now we operate on just a single prompt at a time
         transform = transforms.Compose([
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-            transforms.Resize((image.width, image.height)), # must be multiple of 64...
+            transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
         ])
-        img = transform(image).unsqueeze(0)
+
+        img = self._scale_and_crop(image)
+        img = transform(img).unsqueeze(0)
+
         preds = self.model(img.repeat(len(prompts),1,1,1), prompts)[0]
         heatmap = torch.sigmoid(preds[0][0]).cpu()
+        self._to_device('cpu')
         return SegmentedGrayscale(image, heatmap)
 
+    def _to_device(self, device):
+        self.model.to(device)
 
-
-
+    def _scale_and_crop(self, image:Image)->Image:
+        scaled_image = Image.new('RGB',(CLIPSEG_SIZE,CLIPSEG_SIZE))
+        if image.width > image.height: # width is constraint
+            scale = CLIPSEG_SIZE / image.width
+        else:
+            scale = CLIPSEG_SIZE / image.height
+        scaled_image.paste(
+            image.resize(
+                (int(scale * image.width),
+                 int(scale * image.height)
+                ),
+                resample=Image.Resampling.LANCZOS
+            ),box=(0,0)
+        )
+        return scaled_image
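Reviewer note: the sketch below is not part of the patch. It is a minimal, assumed usage example of the `Txt2Mask`/`SegmentedGrayscale` API added in `ldm/invoke/txt2mask.py`, mirroring how `Generate._txt2mask()` in `ldm/generate.py` calls it. The image path and the "orange" prompt are placeholders taken from the documentation example, the 0.5 threshold matches the default used by `_txt2mask()`, and it assumes the clipseg weights are installed at `src/clipseg/weights/rd64-uni.pth` as the module expects.

```python
# Minimal sketch (assumed usage, not part of the diff): segment an image with a
# text prompt and turn the clipseg heatmap into a mask, as Generate._txt2mask() does.
from PIL import Image
from ldm.invoke.txt2mask import Txt2Mask

image = Image.open('/path/to/still_life.png').convert('RGB')   # placeholder path

txt2mask = Txt2Mask(device='cpu')              # weights load on CPU; moved to the device only while segmenting
segmented = txt2mask.segment(image, 'orange')  # returns a SegmentedGrayscale wrapping the confidence heatmap

mask = segmented.to_mask(0.5)                  # threshold the heatmap into an 'L' mode image,
mask.save('orange-mask.png')                   # rescaled and cropped back to the original image size

segmented.to_grayscale().save('orange-heatmap.png')        # raw confidence map as a grayscale image
segmented.to_transparent().save('orange-transparent.png')  # original image with the heatmap applied as alpha
```

The mask produced here is exactly what `Generate._txt2mask()` converts to RGB, resizes by the latent downsampling factor, and hands to the inpainting sampler when `-tm` is used on the command line.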