| Create a mask from a text prompt describing part of the image|
+The mask may either be an image with transparent areas, in which case
+the inpainting will occur in the transparent areas only, or a black
+and white image, in which case all black areas will be painted into.
+
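+For example, either kind of mask can be supplied with `-M`/`--init_mask`
+(the file paths here are illustrative):
+
+```
+invoke> a piece of chocolate cake -I /path/to/breakfast.png -M /path/to/breakfast-mask.png
+```
+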
`--text_mask` (short form `-tm`) is a way to generate a mask using a
text description of the part of the image to replace. For example, if
you have an image of a breakfast plate with a bagel, toast and
diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md
index 8fdb97b7b8..1a3dcb5e9d 100644
--- a/docs/features/PROMPTS.md
+++ b/docs/features/PROMPTS.md
@@ -45,7 +45,7 @@ Here's a prompt that depicts what it does.
original prompt:
-`#!bash "A fantastical translucent poney made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
+`#!bash "A fantastical translucent pony made of water and foam, ethereal, radiant, hyperalism, scottish folklore, digital painting, artstation, concept art, smooth, 8 k frostbite 3 engine, ultra detailed, art by artgerm and greg rutkowski and magali villeneuve" -s 20 -W 512 -H 768 -C 7.5 -A k_euler_a -S 1654590180`
![step1](../assets/negative_prompt_walkthru/step1.png)
@@ -110,7 +110,10 @@ See the section below on "Prompt Blending" for more information about how this w
### Cross-Attention Control ('prompt2prompt')
-Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
+Generate an image with a given prompt and then paint over the image
+using the `prompt2prompt` syntax to substitute words in the original
+prompt for words in a new prompt. Based off [bloc97's
+colab](https://github.com/bloc97/CrossAttentionControl).
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
@@ -122,9 +125,26 @@ Denoise with a given prompt and then re-use the attention→pixel maps to substi
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
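+
+For example, a complete invocation using this syntax might look like
+this (the seed is illustrative):
+
+```
+invoke> a (fluffy cat).swap(smiling dog) eating a hotdog -s 20 -S 12345
+```
+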
+Note that `prompt2prompt` is not currently working with the runwayML
+inpainting model, and may never work due to the way this model is set
+up. If you attempt to use `prompt2prompt` you will get the original
+image back. However, since this model is so good at inpainting, a
+good substitute is to use the `clipseg` text masking option:
+
+```
+invoke> a fluffy cat eating a hotdog
+Outputs:
+[1010] outputs/000025.2182095108.png: a fluffy cat eating a hotdog
+invoke> a smiling dog eating a hotdog -I 000025.2182095108.png -tm cat
+```
+
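+You can combine a text mask with the new `--invert_mask` option, which
+flips the mask so that the selected region is preserved and everything
+else is repainted. For example (the prompt and file name are
+illustrative):
+
+```
+invoke> a cat sitting in a sunny meadow -I 000025.2182095108.png -tm cat --invert_mask
+```
+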
### Escaping parentheses () and speech marks ""
-If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" these using a backslash, so that`(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and it will get confused.
+If the model you are using has parentheses () or speech marks "" as
+part of its syntax, you will need to "escape" these using a backslash,
+so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt
+parser will attempt to interpret the parentheses as part of the prompt
+syntax and it will get confused.
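+For example (an illustrative prompt):
+
+```
+invoke> a painting of a dog named \"Rex\" sitting in a \(cozy\) armchair
+```
+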
## **Prompt Blending**
diff --git a/ldm/generate.py b/ldm/generate.py
index 3785be56bb..ff0357adfd 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -274,6 +274,7 @@ class Generate:
init_img = None,
init_mask = None,
text_mask = None,
+ invert_mask = False,
fit = False,
strength = None,
init_color = None,
@@ -311,6 +312,7 @@ class Generate:
init_img // path to an initial image
init_mask // path to a mask for the initial image
text_mask // a text string that will be used to guide clipseg generation of the init_mask
+ invert_mask // boolean, if true invert the mask
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
facetool_strength // strength for GFPGAN/CodeFormer. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
@@ -418,22 +420,11 @@ class Generate:
height,
fit=fit,
text_mask=text_mask,
+ invert_mask=invert_mask,
)
# TODO: Hacky selection of operation to perform. Needs to be refactored.
- if self.sampler.conditioning_key() in ('hybrid','concat'):
- print(f'** Inpainting model detected. Will try it! **')
- generator = self._make_omnibus()
- elif (init_image is not None) and (mask_image is not None):
- generator = self._make_inpaint()
- elif (embiggen != None or embiggen_tiles != None):
- generator = self._make_embiggen()
- elif init_image is not None:
- generator = self._make_img2img()
- elif hires_fix:
- generator = self._make_txt2img2img()
- else:
- generator = self._make_txt2img()
+ generator = self.select_generator(init_image, mask_image, embiggen, hires_fix)
generator.set_variation(
self.seed, variation_amount, with_variations
@@ -549,7 +540,7 @@ class Generate:
# try to reuse the same filename prefix as the original file.
# we take everything up to the first period
prefix = None
- m = re.match('^([^.]+)\.',os.path.basename(image_path))
+ m = re.match(r'^([^.]+)\.',os.path.basename(image_path))
if m:
prefix = m.groups()[0]
@@ -603,10 +594,9 @@ class Generate:
elif tool == 'embiggen':
# fetch the metadata from the image
- generator = self._make_embiggen()
+ generator = self.select_generator(embiggen=True)
opt.strength = 0.40
print(f'>> Setting img2img strength to {opt.strength} for happy embiggening')
- # embiggen takes a image path (sigh)
generator.generate(
prompt,
sampler = self.sampler,
@@ -640,6 +630,31 @@ class Generate:
print(f'* postprocessing tool {tool} is not yet supported')
return None
+ def select_generator(
+ self,
+ init_image:Image.Image=None,
+ mask_image:Image.Image=None,
+ embiggen=None,   # list of embiggen scale factors, True, or None
+ hires_fix:bool=False,
+ ):
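+ '''
+ Return the generator appropriate for the supplied arguments.
+ Dispatch order: hires_fix, then embiggen, then the omnibus generator
+ (when an inpainting model is loaded), then inpaint, img2img, txt2img.
+ '''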
+ inpainting_model_in_use = self.sampler.uses_inpainting_model()
+
+ if hires_fix:
+ return self._make_txt2img2img()
+
+ if embiggen is not None:
+ return self._make_embiggen()
+
+ if inpainting_model_in_use:
+ return self._make_omnibus()
+
+ if (init_image is not None) and (mask_image is not None):
+ return self._make_inpaint()
+
+ if init_image is not None:
+ return self._make_img2img()
+
+ return self._make_txt2img()
def _make_images(
self,
@@ -649,6 +664,7 @@ class Generate:
height,
fit=False,
text_mask=None,
+ invert_mask=False,
):
init_image = None
init_mask = None
@@ -678,6 +694,9 @@ class Generate:
elif text_mask:
init_mask = self._txt2mask(image, text_mask, width, height, fit=fit)
+ if invert_mask:
+ init_mask = ImageOps.invert(init_mask)
+
return init_image,init_mask
# lots o' repeated code here! Turn into a make_func()
@@ -855,6 +874,8 @@ class Generate:
def sample_to_image(self, samples):
return self._make_base().sample_to_image(samples)
+ # very repetitive code - can this be simplified? The KSampler names are
+ # consistent, at least
def _set_sampler(self):
msg = f'>> Setting Sampler to {self.sampler_name}'
if self.sampler_name == 'plms':
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 26c5dc1c81..3c1bf2eae4 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -705,6 +705,11 @@ class Args(object):
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
+ img2img_group.add_argument(
+ '--invert_mask',
+ action='store_true',
+ help='Invert the mask',
+ )
img2img_group.add_argument(
'-tm',
'--text_mask',
diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py
index 2e96c93cbb..77b92d693d 100644
--- a/ldm/invoke/generator/base.py
+++ b/ldm/invoke/generator/base.py
@@ -29,7 +29,8 @@ class Generator():
self.threshold = 0
self.variation_amount = 0
self.with_variations = []
- self.use_mps_noise = False
+ self.use_mps_noise = False
+ self.free_gpu_mem = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self,prompt,**kwargs):
@@ -50,7 +51,7 @@ class Generator():
**kwargs):
scope = choose_autocast(self.precision)
self.safety_checker = safety_checker
- make_image = self.get_make_image(
+ make_image = self.get_make_image(
prompt,
sampler = sampler,
init_image = init_image,
diff --git a/ldm/invoke/generator/embiggen.py b/ldm/invoke/generator/embiggen.py
index 53fbde68cf..dc6af35a6c 100644
--- a/ldm/invoke/generator/embiggen.py
+++ b/ldm/invoke/generator/embiggen.py
@@ -21,6 +21,7 @@ class Embiggen(Generator):
def generate(self,prompt,iterations=1,seed=None,
image_callback=None, step_callback=None,
**kwargs):
+
scope = choose_autocast(self.precision)
make_image = self.get_make_image(
prompt,
@@ -63,6 +64,8 @@ class Embiggen(Generator):
Returns a function returning an image derived from the prompt and multi-stage twice-baked potato layering over the img2img on the initial image
Return value depends on the seed at the time you call it
"""
+ assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"
+
# Construct embiggen arg array, and sanity check arguments
if embiggen == None: # embiggen can also be called with just embiggen_tiles
embiggen = [1.0] # If not specified, assume no scaling
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 9fad2d80e1..a11f1bc69d 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -5,10 +5,11 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
import math
-from ldm.invoke.generator.base import Generator
+from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
+from ldm.invoke.generator.omnibus import Omnibus
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-
+from PIL import Image
class Txt2Img2Img(Generator):
def __init__(self, model, precision):
@@ -25,16 +26,16 @@ class Txt2Img2Img(Generator):
"""
uc, c, extra_conditioning_info = conditioning
- @torch.no_grad()
- def make_image(x_T):
-
- trained_square = 512 * 512
- actual_square = width * height
- scale = math.sqrt(trained_square / actual_square)
+ trained_square = 512 * 512
+ actual_square = width * height
+ scale = math.sqrt(trained_square / actual_square)
+
+ init_width = math.ceil(scale * width / 64) * 64
+ init_height = math.ceil(scale * height / 64) * 64
+
+ @torch.no_grad()
+ def make_image(x_T):
- init_width = math.ceil(scale * width / 64) * 64
- init_height = math.ceil(scale * height / 64) * 64
-
shape = [
self.latent_channels,
init_height // self.downsampling_factor,
@@ -105,8 +106,49 @@ class Txt2Img2Img(Generator):
return self.sample_to_image(samples)
- return make_image
-
+ # In the case of the inpainting model being loaded, the trick of
+ # providing an interpolated latent doesn't work, so we transiently
+ # create a 512x512 PIL image, upscale it, and run the inpainting model
+ # over it in img2img mode. Because the inpainting model is so
+ # conservative, it doesn't change the image (much).
+ def inpaint_make_image(x_T):
+ omnibus = Omnibus(self.model,self.precision)
+ result = omnibus.generate(
+ prompt,
+ sampler=sampler,
+ width=init_width,
+ height=init_height,
+ step_callback=step_callback,
+ steps = steps,
+ cfg_scale = cfg_scale,
+ ddim_eta = ddim_eta,
+ conditioning = conditioning,
+ **kwargs
+ )
+ assert result is not None and len(result)>0,'** txt2img failed **'
+ image = result[0][0]
+ interpolated_image = image.resize((width,height),resample=Image.Resampling.LANCZOS)
+ kwargs.pop('init_image', None)   # discard any stale init_image; the interpolated image is passed explicitly
+ result = omnibus.generate(
+ prompt,
+ sampler=sampler,
+ init_image=interpolated_image,
+ width=width,
+ height=height,
+ seed=result[0][1],
+ step_callback=step_callback,
+ steps = steps,
+ cfg_scale = cfg_scale,
+ ddim_eta = ddim_eta,
+ conditioning = conditioning,
+ **kwargs
+ )
+ return result[0][0]
+
+ if sampler.uses_inpainting_model():
+ return inpaint_make_image
+ else:
+ return make_image
# returns a tensor filled with random numbers from a normal distribution
def get_noise(self,width,height,scale = True):
@@ -134,3 +176,4 @@ class Txt2Img2Img(Generator):
scaled_height // self.downsampling_factor,
scaled_width // self.downsampling_factor],
device=device)
+
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index f31f5b1758..f92ceb3d3f 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -272,7 +272,6 @@ class Sampler(object):
)
if mask is not None:
- print('DEBUG: in masking routine')
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts
@@ -438,24 +437,5 @@ class Sampler(object):
def conditioning_key(self)->str:
return self.model.model.conditioning_key
- # def make_cond_in(self, uncond, cond):
- # '''
- # This handles the choice between a conditional conditioning
- # that is a tensor (used by cross attention) vs one that is a dict
- # used by 'hybrid'
- # '''
- # if isinstance(cond, dict):
- # assert isinstance(uncond, dict)
- # cond_in = dict()
- # for k in cond:
- # if isinstance(cond[k], list):
- # cond_in[k] = [
- # torch.cat([uncond[k][i], cond[k][i]])
- # for i in range(len(cond[k]))
- # ]
- # else:
- # cond_in[k] = torch.cat([uncond[k], cond[k]])
- # else:
- # cond_in = torch.cat([uncond, cond])
- # return cond_in
-
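+ # Models whose conditioning key is 'hybrid' or 'concat' (e.g. the runwayML
+ # inpainting checkpoint) expect the mask and masked-image latents to be
+ # concatenated onto the model input, so they need the omnibus pipeline.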
+ def uses_inpainting_model(self)->bool:
+ return self.conditioning_key() in ('hybrid','concat')
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 422179e1c6..cf8644a7fb 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -439,7 +439,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
- print(f'DEBUG text={text}, max_length={self.max_length}')
batch_encoding = self.tokenizer(
text,
truncation=True,
diff --git a/scripts/invoke.py b/scripts/invoke.py
index 466536bc46..1937d47ab5 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -18,6 +18,7 @@ from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from pathlib import Path
+from pyparsing import ParseException
# global used in multiple functions (fix)
infile = None
@@ -328,12 +329,16 @@ def main_loop(gen, opt):
if operation == 'generate':
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
opt.last_operation='generate'
- gen.prompt2image(
- image_callback=image_writer,
- step_callback=step_callback,
- catch_interrupts=catch_ctrl_c,
- **vars(opt)
- )
+ try:
+ gen.prompt2image(
+ image_callback=image_writer,
+ step_callback=step_callback,
+ catch_interrupts=catch_ctrl_c,
+ **vars(opt)
+ )
+ except ParseException as e:
+ print('** An error occurred while processing your prompt **')
+ print(f'** {str(e)} **')
elif operation == 'postprocess':
print(f'>> fixing {opt.prompt}')
opt.last_operation = do_postprocess(gen,opt,image_writer)
@@ -592,7 +597,9 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
def do_textmask(gen, opt, callback):
image_path = opt.prompt
- assert os.path.exists(image_path), '** "{image_path}" not found. Please enter the name of an existing image file to mask **'
+ if not os.path.exists(image_path):
+ image_path = os.path.join(opt.outdir,image_path)
+ assert os.path.exists(image_path), f'** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **'
assert opt.text_mask is not None and len(opt.text_mask) >= 1, '** Please provide a text mask with -tm **'
tm = opt.text_mask[0]
threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5