add !mask command to view output of clipseg

- The !mask command takes an image path, a text prompt, and
  (optionally) a masking threshold. It creates a mask over the region
  indicated by the prompt, and outputs several files that show which
  regions will be masked by the chosen prompt and threshold.

- The mask images should not be passed directly to img2img because
  they are designed for visualization only. Instead, use the
  --text_mask option to pass the selected prompt and threshold.

- See docs/features/INPAINTING.md for details.
This commit is contained in:
Lincoln Stein
2022-10-20 02:33:07 -04:00
parent 63f274f6df
commit a357bf4f19
10 changed files with 142 additions and 9 deletions

View File

@ -225,9 +225,13 @@ def main_loop(gen, opt, infile):
os.makedirs(opt.outdir)
current_outdir = opt.outdir
# write out the history at this point
# Write out the history at this point.
# TODO: Fix the parsing of command-line parameters
# so that !operations don't need to be stripped and readded
if operation == 'postprocess':
completer.add_history(f'!fix {command}')
elif operation == 'mask':
completer.add_history(f'!mask {command}')
else:
completer.add_history(command)
@ -247,13 +251,28 @@ def main_loop(gen, opt, infile):
# when the -v switch is used to generate variations
nonlocal prior_variations
nonlocal prefix
if use_prefix is not None:
prefix = use_prefix
path = None
if opt.grid:
grid_images[seed] = image
elif operation == 'mask':
filename = f'{prefix}.{use_prefix}.{seed}.png'
tm = opt.text_mask[0]
th = opt.text_mask[1] if len(opt.text_mask)>1 else 0.5
formatted_dream_prompt = f'!mask {opt.prompt} -tm {tm} {th}'
path = file_writer.save_image_and_prompt_to_png(
image = image,
dream_prompt = formatted_dream_prompt,
metadata = {},
name = filename,
compress_level = opt.png_compression,
)
results.append([path, formatted_dream_prompt])
else:
if use_prefix is not None:
prefix = use_prefix
postprocessed = upscaled if upscaled else operation=='postprocess'
filename, formatted_dream_prompt = prepare_image_metadata(
opt,
@ -292,7 +311,7 @@ def main_loop(gen, opt, infile):
results.append([path, formatted_dream_prompt])
# so that the seed autocompletes (on linux|mac when -S or --seed is specified)
if completer:
if completer and operation == 'generate':
completer.add_seed(seed)
completer.add_seed(first_seed)
last_results.append([path, seed])
@ -310,6 +329,10 @@ def main_loop(gen, opt, infile):
print(f'>> fixing {opt.prompt}')
opt.last_operation = do_postprocess(gen,opt,image_writer)
elif operation == 'mask':
print(f'>> generating masks from {opt.prompt}')
do_textmask(gen, opt, image_writer)
if opt.grid and len(grid_images) > 0:
grid_img = make_grid(list(grid_images.values()))
grid_seeds = list(grid_images.keys())
@ -355,6 +378,10 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
command = command.replace('!fix ','',1)
operation = 'postprocess'
elif command.startswith('!mask'):
command = command.replace('!mask ','',1)
operation = 'mask'
elif command.startswith('!switch'):
model_name = command.replace('!switch ','',1)
gen.set_model(model_name)
@ -363,6 +390,7 @@ def do_command(command:str, gen, opt:Args, completer) -> tuple:
elif command.startswith('!models'):
gen.model_cache.print_models()
completer.add_history(command)
operation = None
elif command.startswith('!import'):
@ -494,6 +522,19 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False):
os.rename(tmpfile,conf_path)
return True
def do_textmask(gen, opt, callback):
    """Run the text-mask operation for the ``!mask`` command.

    ``opt.prompt`` is treated as the path of the image to mask, and
    ``opt.text_mask`` as a sequence whose first element is the mask prompt
    and whose optional second element is the masking threshold (0.5 when
    omitted). The work itself is delegated to ``gen.apply_textmask``.

    Raises:
        AssertionError: if the image path does not exist or no text mask
            was supplied.
        ValueError: if the supplied threshold cannot be converted to float.
    """
    image_path = opt.prompt  # the "prompt" of a !mask command is the image pathname

    # Raise AssertionError explicitly rather than using `assert`, so the
    # validation still runs under `python -O` while the exception type stays
    # the same as before.
    if not os.path.exists(image_path):
        # This message must be an f-string; otherwise the user is shown the
        # literal text "{image_path}" instead of the missing filename.
        raise AssertionError(
            f'** "{image_path}" not found. Please enter the name of an existing image file to mask **'
        )
    if opt.text_mask is None or len(opt.text_mask) < 1:
        raise AssertionError('** Please provide a text mask with -tm **')

    tm = opt.text_mask[0]
    threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5
    gen.apply_textmask(
        image_path=image_path,
        prompt=tm,
        threshold=threshold,
        callback=callback,
    )
def do_postprocess (gen, opt, callback):
file_path = opt.prompt # treat the prompt as the file pathname
if os.path.dirname(file_path) == '': #basename given
@ -670,7 +711,7 @@ def load_face_restoration(opt):
print(traceback.format_exc(), file=sys.stderr)
print('>> You may need to install the ESRGAN and/or GFPGAN modules')
return gfpgan,codeformer,esrgan
def make_step_callback(gen, opt, prefix):
destination = os.path.join(opt.outdir,'intermediates',prefix)
os.makedirs(destination,exist_ok=True)