From 799dc6d0df9e50eafc1043623a0ed7a10ed54e80 Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 27 Oct 2022 01:51:35 -0400
Subject: [PATCH] acceptable integration of new prompting system and inpainting

This was a difficult merge because both PRs #1108 and #1243 made changes
to obscure parts of the diffusion code.

- prompt weighting, merging and cross-attention working
- cross-attention does not work with runwayML inpainting model, but
  weighting and merging are tested and working
- CLI command parsing code rewritten in order to get embedded quotes right
- --hires now works with runwayML inpainting
- --embiggen does not work with runwayML and will give an error
- Added an --invert_mask option to invert masks applied to inpainting
- Updated documentation
---
 docs/features/CLI.md                |  5 +++
 docs/features/PROMPTS.md            | 26 +++++++++--
 ldm/generate.py                     | 53 +++++++++++++++-------
 ldm/invoke/args.py                  |  5 +++
 ldm/invoke/generator/base.py        |  5 ++-
 ldm/invoke/generator/embiggen.py    |  3 ++
 ldm/invoke/generator/txt2img2img.py | 69 +++++++++++++++++++++++------
 ldm/models/diffusion/sampler.py     | 24 +---------
 ldm/modules/encoders/modules.py     |  1 -
 scripts/invoke.py                   | 21 ++++++---
 10 files changed, 148 insertions(+), 64 deletions(-)

diff --git a/docs/features/CLI.md b/docs/features/CLI.md
index 67a187fb3b..2cc8fa6d05 100644
--- a/docs/features/CLI.md
+++ b/docs/features/CLI.md
@@ -218,8 +218,13 @@ well as the --mask (-M) and --text_mask (-tm) arguments:
 | Argument | Shortcut | Default | Description |
 |--------------------|------------|---------------------|--------------|
 | `--init_mask ` | `-M` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
+| `--invert_mask ` | | False |If true, invert the mask so that transparent areas are opaque and vice versa.|
 | `--text_mask []` | `-tm []` | | Create a mask from a text prompt describing part of the image|
 
+The mask may either be an image with transparent areas, in which case
+the inpainting will occur in the transparent areas only, or a black
+and white image, in which case all black areas will be inpainted.
+
 `--text_mask` (short form `-tm`) is a way to generate a mask using a
 text description of the part of the image to replace. For example, if
 you have an image of a breakfast plate with a bagel, toast and
diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md
index 8fdb97b7b8..1a3dcb5e9d 100644
--- a/docs/features/PROMPTS.md
+++ b/docs/features/PROMPTS.md
@@ -45,7 +45,7 @@ Here's a prompt that depicts what it does.
 
 original prompt:
 
-`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
+`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
 
 ![step1](../assets/negative_prompt_walkthru/step1.png)
 
@@ -110,7 +110,10 @@ See the section below on "Prompt Blending" for more information about how this w
 
 ### Cross-Attention Control ('prompt2prompt')
 
-Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
+Generate an image with a given prompt and then paint over the image
+using the `prompt2prompt` syntax to substitute words in the original
+prompt for words in a new prompt. Based off [bloc97's
+colab](https://github.com/bloc97/CrossAttentionControl).
 
 * `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
 * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
@@ -122,9 +125,26 @@ Denoise with a given prompt and then re-use the attention→pixel maps to substi
 * Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
   * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
 
+Note that `prompt2prompt` is not currently working with the runwayML
+inpainting model, and may never work due to the way this model is set
+up. If you attempt to use `prompt2prompt` you will get the original
+image back. However, since this model is so good at inpainting, a
+good substitute is to use the `clipseg` text masking option:
+
+```
+invoke> a fluffy cat eating a hotdog
+Outputs:
+[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
+invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
+```
+
 ### Escaping parantheses () and speech marks ""
 
-If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" these using a backslash, so that`(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and it will get confused.
+If the model you are using has parentheses () or speech marks "" as
+part of its syntax, you will need to "escape" these using a backslash,
+so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt
+parser will attempt to interpret the parentheses as part of the prompt
+syntax and it will get confused.
 
 ## **Prompt Blending**
 
diff --git a/ldm/generate.py b/ldm/generate.py
index 3785be56bb..ff0357adfd 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -274,6 +274,7 @@ class Generate:
             init_img = None,
             init_mask = None,
             text_mask = None,
+            invert_mask = False,
             fit = False,
             strength = None,
             init_color = None,
@@ -311,6 +312,7 @@ class Generate:
            init_img // path to an initial image
            init_mask // path to a mask for the initial image
            text_mask // a text string that will be used to guide clipseg generation of the init_mask
+           invert_mask // boolean, if true invert the mask
            strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
            facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
            ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@@ -418,22 +420,11 @@ class Generate:
                 height,
                 fit=fit,
                 text_mask=text_mask,
+                invert_mask=invert_mask,
             )
 
             # TODO: Hacky selection of operation to perform. Needs to be refactored.
-            if self.sampler.conditioning_key() in ('hybrid','concat'):
-                print(f'** Inpainting model detected. Will try it! **')
-                generator = self._make_omnibus()
-            elif (init_image is not None) and (mask_image is not None):
-                generator = self._make_inpaint()
-            elif (embiggen != None or embiggen_tiles != None):
-                generator = self._make_embiggen()
-            elif init_image is not None:
-                generator = self._make_img2img()
-            elif hires_fix:
-                generator = self._make_txt2img2img()
-            else:
-                generator = self._make_txt2img()
+            generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
 
             generator.set_variation(
                 self.seed, variation_amount, with_variations
             )
@@ -549,7 +540,7 @@
         # try to reuse the same filename prefix as the original file.
         # we take everything up to the first period
         prefix = None
-        m = re.match('^([^.]+)\.',os.path.basename(image_path))
+        m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
         if m:
             prefix = m.groups()[0]
 
@@ -603,10 +594,9 @@
 
         elif tool == 'embiggen':
             # fetch the metadata from the image
-            generator = self._make_embiggen()
+            generator = self.select_generator(embiggen=True)
            opt.strength = 0.40
            print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
-           # embiggen takes a image path (sigh)
            generator.generate(
                prompt,
                sampler = self.sampler,
@@ -640,6 +630,31 @@
             print(f'* postprocessing tool {tool} is not yet supported')
             return None
 
+    def select_generator(
+            self,
+            init_image:Image.Image=None,
+            mask_image:Image.Image=None,
+            embiggen:bool=False,
+            hires_fix:bool=False,
+    ):
+        inpainting_model_in_use = self.sampler.uses_inpainting_model()
+
+        if hires_fix:
+            return self._make_txt2img2img()
+
+        if embiggen is not None:
+            return self._make_embiggen()
+
+        if inpainting_model_in_use:
+            return self._make_omnibus()
+
+        if (init_image is not None) and (mask_image is not None):
+            return self._make_inpaint()
+
+        if init_image is not None:
+            return self._make_img2img()
+
+        return self._make_txt2img()
 
     def _make_images(
             self,
@@ -649,6 +664,7 @@
             height,
             fit=False,
             text_mask=None,
+            invert_mask=False,
     ):
         init_image = None
         init_mask = None
@@ -678,6 +694,9 @@
         elif text_mask:
             init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
 
+        if invert_mask:
+            init_mask = ImageOps.invert(init_mask)
+
         return init_image,init_mask
 
     # lots o' repeated code here! Turn into a make_func()
@@ -855,6 +874,8 @@
     def sample_to_image(self, samples):
         return self._make_base().sample_to_image(samples)
 
+    # very repetitive code - can this be simplified? The KSampler names are
+    # consistent, at least
     def _set_sampler(self):
         msg = f'>> Setting Sampler to {self.sampler_name}'
         if self.sampler_name == 'plms':
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 26c5dc1c81..3c1bf2eae4 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -705,6 +705,11 @@ class Args(object):
             type=str,
             help='Path to input mask for inpainting mode (supersedes width and height)',
         )
+        img2img_group.add_argument(
+            '--invert_mask',
+            action='store_true',
+            help='Invert the mask',
+        )
         img2img_group.add_argument(
             '-tm',
             '--text_mask',
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 2e96c93cbb..77b92d693d 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -29,7 +29,8 @@ class Generator():
         self.threshold = 0
         self.variation_amount = 0
         self.with_variations = []
-        self.use_mps_noise = False
+        self.use_mps_noise = False
+        self.free_gpu_mem = None
 
     # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
     def get_make_image(self,prompt,**kwargs):
@@ -50,7 +51,7 @@ class Generator():
                  **kwargs):
         scope = choose_autocast(self.precision)
         self.safety_checker = safety_checker
-        make_image = self.get_make_image(
+        make_image = self.get_make_image(
             prompt,
             sampler = sampler,
             init_image = init_image,
diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py
index 53fbde68cf..dc6af35a6c 100644
--- a/ldm/invoke/generator/embiggen.py
+++ b/ldm/invoke/generator/embiggen.py
@@ -21,6 +21,7 @@ class Embiggen(Generator):
     def generate(self,prompt,iterations=1,seed=None,
                  image_callback=None, step_callback=None,
                  **kwargs):
+        scope = choose_autocast(self.precision)
         make_image = self.get_make_image(
             prompt,
@@ -63,6 +64,8 @@ class Embiggen(Generator):
         Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
         Return value depends on the seed at the time you call it
         """
+        assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
+
         # Construct embiggen arg array, and sanity check arguments
         if embiggen == None: # embiggen can also be called with just embiggen_tiles
             embiggen = [1.0] # If not specified, assume no scaling
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 9fad2d80e1..a11f1bc69d 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -5,10 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
 import torch
 import numpy as np
 import math
-from ldm.invoke.generator.base import Generator
+from ldm.invoke.generator.base import Generator
 from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.invoke.generator.omnibus import Omnibus
 from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-
+from PIL import Image
 
 class Txt2Img2Img(Generator):
     def __init__(self, model, precision):
@@ -25,16 +26,16 @@ class Txt2Img2Img(Generator):
         """
         uc, c, extra_conditioning_info = conditioning
 
-        @torch.no_grad()
-        def make_image(x_T):
-
-            trained_square = 512 * 512
-            actual_square = width * height
-            scale = math.sqrt(trained_square / actual_square)
+        trained_square = 512 * 512
+        actual_square = width * height
+        scale = math.sqrt(trained_square / actual_square)
+
+        init_width = math.ceil(scale * width / 64) * 64
+        init_height = math.ceil(scale * height / 64) * 64
+
+        @torch.no_grad()
+        def make_image(x_T):
 
-            init_width = math.ceil(scale * width / 64) * 64
-            init_height = math.ceil(scale * height / 64) * 64
-
             shape = [
                 self.latent_channels,
                 init_height // self.downsampling_factor,
@@ -105,8 +106,49 @@ class Txt2Img2Img(Generator):
 
             return self.sample_to_image(samples)
 
-        return make_image
-
+        # in the case of the inpainting model being loaded, the trick of
+        # providing an interpolated latent doesn't work, so we transiently
+        # create a 512x512 PIL image, upscale it, and run the inpainting
+        # over it in img2img mode. Because the inpainting model is so conservative
+        # it doesn't change the image (much)
+        def inpaint_make_image(x_T):
+            omnibus = Omnibus(self.model,self.precision)
+            result = omnibus.generate(
+                prompt,
+                sampler=sampler,
+                width=init_width,
+                height=init_height,
+                step_callback=step_callback,
+                steps = steps,
+                cfg_scale = cfg_scale,
+                ddim_eta = ddim_eta,
+                conditioning = conditioning,
+                **kwargs
+            )
+            assert result is not None and len(result)>0,'** txt2img failed **'
+            image = result[0][0]
+            interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
+            kwargs.pop('init_image',None)  # discard any caller-supplied init_image; we pass our own below
+            result = omnibus.generate(
+                prompt,
+                sampler=sampler,
+                init_image=interpolated_image,
+                width=width,
+                height=height,
+                seed=result[0][1],
+                step_callback=step_callback,
+                steps = steps,
+                cfg_scale = cfg_scale,
+                ddim_eta = ddim_eta,
+                conditioning = conditioning,
+                **kwargs
+            )
+            return result[0][0]
+
+        if sampler.uses_inpainting_model():
+            return inpaint_make_image
+        else:
+            return make_image
 
     # returns a tensor filled with random numbers from a normal distribution
     def get_noise(self,width,height,scale = True):
@@ -134,3 +176,4 @@ class Txt2Img2Img(Generator):
                              scaled_height // self.downsampling_factor,
                              scaled_width // self.downsampling_factor],
                              device=device)
+
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index f31f5b1758..f92ceb3d3f 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -272,7 +272,6 @@ class Sampler(object):
             )
             if mask is not None:
-                print('DEBUG: in masking routine')
                 assert x0 is not None
                 img_orig = self.model.q_sample(
                     x0, ts
                 )
@@ -438,24 +437,5 @@ class Sampler(object):
 
     def conditioning_key(self)->str:
         return self.model.model.conditioning_key
-    # def make_cond_in(self, uncond, cond):
-    #     '''
-    #     This handles the choice between a conditional conditioning
-    #     that is a tensor (used by cross attention) vs one that is a dict
-    #     used by 'hybrid'
-    #     '''
-    #     if isinstance(cond, dict):
-    #         assert isinstance(uncond, dict)
-    #         cond_in = dict()
-    #         for k in cond:
-    #             if isinstance(cond[k], list):
-    #                 cond_in[k] = [
-    #                     torch.cat([uncond[k][i], cond[k][i]])
-    #                     for i in range(len(cond[k]))
-    #                 ]
-    #             else:
-    #                 cond_in[k] = torch.cat([uncond[k], cond[k]])
-    #     else:
-    #         cond_in = torch.cat([uncond, cond])
-    #     return cond_in
-
+    def uses_inpainting_model(self)->bool:
+        return self.conditioning_key() in ('hybrid','concat')
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 422179e1c6..cf8644a7fb 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -439,7 +439,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             param.requires_grad = False
 
     def forward(self, text, **kwargs):
-        print(f'DEBUG text={text}, max_length={self.max_length}')
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
diff --git a/scripts/invoke.py b/scripts/invoke.py
index 466536bc46..1937d47ab5 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from omegaconf import OmegaConf
 from pathlib import Path
+from pyparsing import ParseException
 
 # global used in multiple functions (fix)
 infile = None
@@ -328,12 +329,16 @@ def main_loop(gen, opt):
         if operation == 'generate':
             catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
             opt.last_operation='generate'
-            gen.prompt2image(
-                image_callback=image_writer,
-                step_callback=step_callback,
-                catch_interrupts=catch_ctrl_c,
-                **vars(opt)
-            )
+            try:
+                gen.prompt2image(
+                    image_callback=image_writer,
+                    step_callback=step_callback,
+                    catch_interrupts=catch_ctrl_c,
+                    **vars(opt)
+                )
+            except ParseException as e:
+                print('** An error occurred while processing your prompt **')
+                print(f'** {str(e)} **')
         elif operation == 'postprocess':
             print(f'>> fixing {opt.prompt}')
             opt.last_operation = do_postprocess(gen,opt,image_writer)
@@ -592,7 +597,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
 
 def do_textmask(gen, opt, callback):
     image_path = opt.prompt
-    assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
+    if not os.path.exists(image_path):
+        image_path = os.path.join(opt.outdir,image_path)
+    assert os.path.exists(image_path), f'** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **'
     assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
     tm = opt.text_mask[0]
     threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5
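
For illustration only, and not part of the patch itself: a hypothetical `invoke>` session exercising the mask options documented above. The file names and the `0.6` threshold are made up; the flags (`-I`, `-M`, `-tm` with an optional threshold, and the new `--invert_mask`) are the ones this patch adds or documents.

```
invoke> a basket of croissants -I 000025.2182095108.png -tm bagel 0.6
invoke> a basket of croissants -I 000025.2182095108.png -M bagel_mask.png --invert_mask
```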