acceptable integration of new prompting system and inpainting

This was a difficult merge because both PR #1108 and #1243 made
changes to obscure parts of the diffusion code.

- prompt weighting, merging and cross-attention working
  - cross-attention does not work with runwayML inpainting
    model, but weighting and merging are tested and working
- CLI command parsing code rewritten in order to get embedded
  quotes right
- --hires now works with runwayML inpainting
- --embiggen does not work with runwayML and will give an error
- Added an --invert_mask option to invert masks applied to inpainting (see the sketch below)
- Updated documentation
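
As an illustration only (not part of the commit), here is a minimal sketch of how the new mask options might be exercised through the `Generate` API whose `prompt2image()` signature is extended below. The import path, default construction, and file paths are assumptions:

```python
from ldm.generate import Generate  # assumed import path for the Generate class edited below

gen = Generate()  # assumed defaults; real scripts typically pass model/config options

# prompt2image() gains invert_mask in this commit; the other parameters shown
# here already existed (see the docstring in the diff). Paths are placeholders.
results = gen.prompt2image(
    'a smiling dog eating a hotdog',
    init_img    = 'outputs/breakfast.png',        # image to inpaint
    init_mask   = 'outputs/breakfast_mask.png',   # mask image (transparent or black/white)
    invert_mask = True,                           # new flag: flip which region gets repainted
    strength    = 0.75,
)
```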
Lincoln Stein 2022-10-27 01:51:35 -04:00
parent 79689e87ce
commit 799dc6d0df
10 changed files with 148 additions and 64 deletions

View File

@@ -218,8 +218,13 @@ well as the --mask (-M) and --text_mask (-tm) arguments:
| Argument <img width="100" align="right"/> | Shortcut | Default | Description |
|--------------------|------------|---------------------|--------------|
| `--init_mask <path>` | `-M<path>` | `None` |Path to an image the same size as the initial_image, with areas for inpainting made transparent.|
| `--invert_mask ` | | False |If true, invert the mask so that transparent areas are opaque and vice versa.|
| `--text_mask <prompt> [<float>]` | `-tm <prompt> [<float>]` | <none> | Create a mask from a text prompt describing part of the image|
The mask may either be an image with transparent areas, in which case
the inpainting will occur in the transparent areas only, or a black
and white image, in which case all black areas will be painted into.
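
As a small illustration of these two conventions, here is an assumed PIL sketch that builds a black-and-white mask and flips it the way `--invert_mask` does internally (it mirrors the `ImageOps.invert` call this commit adds to `ldm/generate.py`; sizes and file names are placeholders):

```python
from PIL import Image, ImageOps

# Black-and-white mask: black areas are the ones that will be painted into.
mask = Image.new('L', (512, 512), 255)   # start fully white (nothing repainted)
mask.paste(0, (128, 128, 384, 384))      # black rectangle marks the region to repaint

# --invert_mask swaps the convention, so the area *outside* the rectangle
# becomes the inpainting target instead.
inverted = ImageOps.invert(mask)

mask.save('mask.png')             # usable with -M mask.png
inverted.save('mask_inverted.png')
```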
`--text_mask` (short form `-tm`) is a way to generate a mask using a
text description of the part of the image to replace. For example, if
you have an image of a breakfast plate with a bagel, toast and

View File

@@ -45,7 +45,7 @@ Here's a prompt that depicts what it does.
original prompt:
`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
<div align="center" markdown>
![step1](../assets/negative_prompt_walkthru/step1.png)
@@ -110,7 +110,10 @@ See the section below on "Prompt Blending" for more information about how this w
### Cross-Attention Control ('prompt2prompt')
Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
Generate an image with a given prompt and then paint over the image
using the `prompt2prompt` syntax to substitute words in the original
prompt for words in a new prompt. Based off [bloc97's
colab](https://github.com/bloc97/CrossAttentionControl).
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`. * `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`. * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
@ -122,9 +125,26 @@ Denoise with a given prompt and then re-use the attention→pixel maps to substi
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped. * Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`. * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
Note that `prompt2prompt` is not currently working with the runwayML
inpainting model, and may never work due to the way this model is set
up. If you attempt to use `prompt2prompt` you will get the original
image back. However, since this model is so good at inpainting, a
good substitute is to use the `clipseg` text masking option:
```
invoke> a fluffy cat eating a hotdog
Outputs:
[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
```
### Escaping parentheses () and speech marks ""
If the model you are using has parentheses () or speech marks "" as
part of its syntax, you will need to "escape" these using a backslash,
so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt
parser will attempt to interpret the parentheses as part of the prompt
syntax and it will get confused.
## **Prompt Blending**

View File

@@ -274,6 +274,7 @@ class Generate:
init_img = None,
init_mask = None,
text_mask = None,
invert_mask = False,
fit = False,
strength = None,
init_color = None,
@@ -311,6 +312,7 @@ class Generate:
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
invert_mask // boolean, if true invert the mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@@ -418,22 +420,11 @@ class Generate:
height,
fit=fit,
text_mask=text_mask,
invert_mask=invert_mask,
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
if self.sampler.conditioning_key() in ('hybrid','concat'):
print(f'** Inpainting model detected. Will try it! **')
generator = self._make_omnibus()
elif (init_image is not None) and (mask_image is not None):
generator = self._make_inpaint()
elif (embiggen != None or embiggen_tiles != None):
generator = self._make_embiggen()
elif init_image is not None:
generator = self._make_img2img()
elif hires_fix:
generator = self._make_txt2img2img()
else:
generator = self._make_txt2img()
generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
generator.set_variation(
self.seed, variation_amount, with_variations
@@ -549,7 +540,7 @@ class Generate:
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
prefix = None
m = re.match('^([^.]+)\.',os.path.basename(image_path))
m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
if m:
prefix = m.groups()[0]
@@ -603,10 +594,9 @@ class Generate:
elif tool == 'embiggen':
# fetch the metadata from the image
generator = self._make_embiggen()
generator = self.select_generator(embiggen=True)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
# embiggen takes a image path (sigh)
generator.generate(
prompt,
sampler = self.sampler,
@@ -640,6 +630,31 @@ class Generate:
print(f'* postprocessing tool {tool} is not yet supported')
return None
def select_generator(
self,
init_image:Image.Image=None,
mask_image:Image.Image=None,
embiggen:bool=False,
hires_fix:bool=False,
):
inpainting_model_in_use = self.sampler.uses_inpainting_model()
if hires_fix:
return self._make_txt2img2img()
if embiggen is not None:
return self._make_embiggen()
if inpainting_model_in_use:
return self._make_omnibus()
if (init_image is not None) and (mask_image is not None):
return self._make_inpaint()
if init_image is not None:
return self._make_img2img()
return self._make_txt2img()
def _make_images(
self,
@@ -649,6 +664,7 @@ class Generate:
height,
fit=False,
text_mask=None,
invert_mask=False,
):
init_image = None
init_mask = None
@@ -678,6 +694,9 @@ class Generate:
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
if invert_mask:
init_mask = ImageOps.invert(init_mask)
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
@@ -855,6 +874,8 @@ class Generate:
def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples)
# very repetitive code - can this be simplified? The KSampler names are
# consistent, at least
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms':

View File

@@ -705,6 +705,11 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'--invert_mask',
action='store_true',
help='Invert the mask',
)
img2img_group.add_argument(
'-tm',
'--text_mask',

View File

@@ -29,7 +29,8 @@ class Generator():
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@@ -50,7 +51,7 @@ class Generator():
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
make_image = self.get_make_image(
prompt,
sampler = sampler,
init_image = init_image,

View File

@@ -21,6 +21,7 @@ class Embiggen(Generator):
def generate(self,prompt,iterations=1,seed=None,
image_callback=None, step_callback=None,
**kwargs):
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
prompt,
@@ -63,6 +64,8 @@ class Embiggen(Generator):
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling

View File

@@ -5,10 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
import math
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from PIL import Image
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
@@ -25,16 +26,16 @@ class Txt2Img2Img(Generator):
"""
uc, c, extra_conditioning_info = conditioning
trained_square = 512 * 512
actual_square = width * height
scale = math.sqrt(trained_square / actual_square)
init_width = math.ceil(scale * width / 64) * 64
init_height = math.ceil(scale * height / 64) * 64
@torch.no_grad()
def make_image(x_T):
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
@@ -105,8 +106,49 @@ class Txt2Img2Img(Generator):
return self.sample_to_image(samples)
return make_image
# in the case of the inpainting model being loaded, the trick of
# providing an interpolated latent doesn't work, so we transiently
# create a 512x512 PIL image, upscale it, and run the inpainting
# over it in img2img mode. Because the inpainting model is so conservative
# it doesn't change the image (much)
def inpaint_make_image(x_T):
omnibus = Omnibus(self.model,self.precision)
result = omnibus.generate(
prompt,
sampler=sampler,
width=init_width,
height=init_height,
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
assert result is not None and len(result)>0,'** txt2img failed **'
image = result[0][0]
interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
print(kwargs.pop('init_image',None))
result = omnibus.generate(
prompt,
sampler=sampler,
init_image=interpolated_image,
width=width,
height=height,
seed=result[0][1],
step_callback=step_callback,
steps = steps,
cfg_scale = cfg_scale,
ddim_eta = ddim_eta,
conditioning = conditioning,
**kwargs
)
return result[0][0]
if sampler.uses_inpainting_model():
return inpaint_make_image
else:
return make_image
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
@@ -134,3 +176,4 @@ class Txt2Img2Img(Generator):
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)

View File

@@ -272,7 +272,6 @@ class Sampler(object):
)
if mask is not None:
print('DEBUG: in masking routine')
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
@@ -438,24 +437,5 @@
def conditioning_key(self)->str:
return self.model.model.conditioning_key
def uses_inpainting_model(self)->bool:
return self.conditioning_key() in ('hybrid','concat')
# def make_cond_in(self, uncond, cond):
# '''
# This handles the choice between a conditional conditioning
# that is a tensor (used by cross attention) vs one that is a dict
# used by 'hybrid'
# '''
# if isinstance(cond, dict):
# assert isinstance(uncond, dict)
# cond_in = dict()
# for k in cond:
# if isinstance(cond[k], list):
# cond_in[k] = [
# torch.cat([uncond[k][i], cond[k][i]])
# for i in range(len(cond[k]))
# ]
# else:
# cond_in[k] = torch.cat([uncond[k], cond[k]])
# else:
# cond_in = torch.cat([uncond, cond])
# return cond_in

View File

@@ -439,7 +439,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
print(f'DEBUG text={text}, max_length={self.max_length}')
batch_encoding = self.tokenizer(
text,
truncation=True,

View File

@@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from pathlib import Path
from pyparsing import ParseException
# global used in multiple functions (fix)
infile = None
@@ -328,12 +329,16 @@ def main_loop(gen, opt):
if operation == 'generate':
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
try:
gen.prompt2image(
image_callback=image_writer,
step_callback=step_callback,
catch_interrupts=catch_ctrl_c,
**vars(opt)
)
except ParseException as e:
print('** An error occurred while processing your prompt **')
print(f'** {str(e)} **')
elif operation == 'postprocess':
print(f'>> fixing {opt.prompt}')
opt.last_operation = do_postprocess(gen,opt,image_writer)
@@ -592,7 +597,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
def do_textmask(gen, opt, callback):
image_path = opt.prompt
assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
if not os.path.exists(image_path):
image_path = os.path.join(opt.outdir,image_path)
assert os.path.exists(image_path), '** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **'
assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
tm = opt.text_mask[0]
threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5