InvokeAI/ldm/invoke/CLI.py
Damian Stewart 786b8878d6
Save and display per-token attention maps (#1866)
* attention maps saving to /tmp

* tidy up diffusers branch backporting of cross attention refactoring

* base64-encoding the attention maps image for generationResult

* cleanup/refactor conditioning.py

* attention maps and tokens being sent to web UI

* attention maps: restrict count to actual token count and improve robustness

* add argument type hint to image_to_dataURL function

Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>

Co-authored-by: damian <git@damianstewart.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
2022-12-10 15:57:41 +01:00

949 lines
34 KiB
Python

import os
import re
import sys
import shlex
import copy
import warnings
import time
import traceback
import yaml
from ldm.generate import Generate
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer, Completer
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from ldm.invoke.concepts_lib import Concepts
from omegaconf import OmegaConf
from pathlib import Path
import pyparsing
# Shared handle for the command script currently being read. main() sets it
# from opt.infile (a file or sys.stdin), the !replay command in do_command()
# sets it to the replayed file, and main_loop() resets it to None at EOF.
# None means commands are read interactively via input().
# global used in multiple functions (fix)
infile = None
def main():
    """Initialize command-line parsers and the diffusion model

    Parses the launch arguments, sets the runtime globals, loads the optional
    restoration/upscaling modules, builds and preloads the Generate object and
    then hands control either to the web server loop or to main_loop().
    Exits the process with -1 on deprecated flags, an unreadable infile, or a
    model that cannot be constructed/loaded.
    """
    global infile
    print('* Initializing, be patient...')

    opt = Args()
    args = opt.parse_args()
    # a falsy result from parse_args() is treated as a fatal parse failure
    if not args:
        sys.exit(-1)

    # reject switches that are no longer supported
    if args.laion400m:
        print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
        sys.exit(-1)
    if args.weights:
        print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
        sys.exit(-1)
    # clamp a non-positive cache size to 1 instead of aborting
    if args.max_loaded_models is not None:
        if args.max_loaded_models <= 0:
            print('--max_loaded_models must be >= 1; using 1')
            args.max_loaded_models = 1

    # alert - setting globals here
    # root precedence: --root_dir, then $INVOKEAI_ROOT, then the current directory
    Globals.root = os.path.expanduser(args.root_dir or os.environ.get('INVOKEAI_ROOT') or os.path.abspath('.'))
    Globals.try_patchmatch = args.patchmatch

    print(f'>> InvokeAI runtime directory is "{Globals.root}"')

    # loading here to avoid long delays on startup
    from ldm.generate import Generate

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules
    gfpgan,codeformer,esrgan = load_face_restoration(opt)

    # normalize the config directory relative to root
    if not os.path.isabs(opt.conf):
        opt.conf = os.path.normpath(os.path.join(Globals.root,opt.conf))

    # resolve the embedding path against root; None when embeddings are disabled
    if opt.embeddings:
        if not os.path.isabs(opt.embedding_path):
            embedding_path = os.path.normpath(os.path.join(Globals.root,opt.embedding_path))
        else:
            embedding_path = opt.embedding_path
    else:
        embedding_path = None

    # open the infile (or stdin when '-') so that main_loop() can read commands from it
    if opt.infile:
        try:
            if os.path.isfile(opt.infile):
                infile = open(opt.infile, 'r', encoding='utf-8')
            elif opt.infile == '-': # stdin
                infile = sys.stdin
            else:
                raise FileNotFoundError(f'{opt.infile} not found.')
        except (FileNotFoundError, IOError) as e:
            print(f'{e}. Aborting.')
            sys.exit(-1)

    # creating a Generate object:
    try:
        gen = Generate(
            conf = opt.conf,
            model = opt.model,
            sampler_name = opt.sampler_name,
            embedding_path = embedding_path,
            full_precision = opt.full_precision,
            precision = opt.precision,
            gfpgan=gfpgan,
            codeformer=codeformer,
            esrgan=esrgan,
            free_gpu_mem=opt.free_gpu_mem,
            safety_checker=opt.safety_checker,
            max_loaded_models=opt.max_loaded_models,
        )
    except (FileNotFoundError, TypeError, AssertionError):
        # these are taken to mean a missing/misconfigured model: run the configurator, then exit
        emergency_model_reconfigure()
        sys.exit(-1)
    except (IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    if opt.seamless:
        print(">> changed to seamless tiling mode")

    # preload the model
    try:
        gen.load_model()
    except AssertionError:
        emergency_model_reconfigure()
        sys.exit(-1)

    # web server loops forever
    if opt.web or opt.gui:
        invoke_ai_web_server_loop(gen, gfpgan, codeformer, esrgan)
        sys.exit(0)

    # the interactive banner is only shown when no command script is being read
    if not infile:
        print(
            "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
        )

    try:
        main_loop(gen, opt)
    except KeyboardInterrupt:
        print("\ngoodbye!")
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt):
    """prompt/read/execute loop

    Repeatedly fetches a command (from the global `infile` or interactively),
    dispatches '!' commands through do_command(), parses the remaining switches
    into `opt`, and runs the selected operation ('generate', 'postprocess' or
    'mask'). Images are written by the nested image_writer() callback and each
    round's results are appended to the invoke_log files in the output directory.
    """
    global infile
    done = False
    # when started with a script, quit once that script is exhausted
    doneAfterInFile = infile is not None
    # characters replaced by '_' when a prompt is used as a directory name
    path_filter = re.compile(r'[<>:"/\\|?*]')
    # [path, seed] pairs of the previous round; indexed by negative -S / -I values
    last_results = list()
    model_config = OmegaConf.load(opt.conf)

    # The readline completer reads history from the .dream_history file located in the
    # output directory specified at the time of script launch. We do not currently support
    # changing the history file midstream when the output directory is changed.
    completer = get_completer(opt, models=list(model_config.keys()))
    set_default_output_dir(opt, completer)
    add_embedding_terms(gen, completer)
    output_cntr = completer.get_current_history_length()+1

    # os.pathconf is not available on Windows
    if hasattr(os, 'pathconf'):
        path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX')
        name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX')
    else:
        path_max = 260
        name_max = 255

    while not done:

        operation = 'generate'

        try:
            command = get_next_command(infile)
        except EOFError:
            # end of script: stop if we were launched with one (or had none),
            # otherwise (a !replay) fall back to interactive input
            # NOTE(review): the file object is dropped here without close() -- confirm whether that is intended
            done = infile is None or doneAfterInFile
            infile = None
            continue

        # skip empty lines
        if not command.strip():
            continue

        # skip comment lines
        if command.startswith(('#', '//')):
            continue

        # a lone 'q' quits
        if len(command.strip()) == 1 and command.startswith('q'):
            done = True
            break

        if not command.startswith('!history'):
            completer.add_history(command)

        if command.startswith('!'):
            command, operation = do_command(command, gen, opt, completer)

        # operation None means do_command() fully handled the command
        if operation is None:
            continue

        if opt.parse_cmd(command) is None:
            continue

        # with an init image but no prompt, try to reuse the prompt stored in the image
        if opt.init_img:
            try:
                if not opt.prompt:
                    oldargs = metadata_from_png(opt.init_img)
                    opt.prompt = oldargs.prompt
                    print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
            except (OSError, AttributeError, KeyError):
                pass

        # NOTE(review): len() would raise TypeError if opt.prompt were None here -- confirm parse_cmd never leaves it None
        if len(opt.prompt) == 0:
            opt.prompt = ''

        # width and height are set by model if not specified
        if not opt.width:
            opt.width = gen.width
        if not opt.height:
            opt.height = gen.height

        # retrieve previous value of init image if requested
        # (a value such as "-1" indexes into last_results from the end)
        if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
            try:
                opt.init_img = last_results[int(opt.init_img)][0]
                print(f'>> Reusing previous image {opt.init_img}')
            except IndexError:
                print(
                    f'>> No previous initial image at position {opt.init_img} found')
                opt.init_img = None
                continue

        # the outdir can change with each command, so we adjust it here
        set_default_output_dir(opt,completer)

        # try to relativize pathnames
        # (paths that do not exist as given are retried relative to the output directory)
        for attr in ('init_img','init_mask','init_color'):
            if getattr(opt,attr) and not os.path.exists(getattr(opt,attr)):
                basename = getattr(opt,attr)
                path = os.path.join(opt.outdir,basename)
                setattr(opt,attr,path)

        # retrieve previous value of seed if requested
        # Exception: for postprocess operations negative seed values
        # mean "discard the original seed and generate a new one"
        # (this is a non-obvious hack and needs to be reworked)
        if opt.seed is not None and opt.seed < 0 and operation != 'postprocess':
            try:
                opt.seed = last_results[opt.seed][1]
                print(f'>> Reusing previous seed {opt.seed}')
            except IndexError:
                print(f'>> No previous seed at position {opt.seed} found')
                opt.seed = None
                continue

        # default strength depends on whether an outpainting direction was given
        if opt.strength is None:
            opt.strength = 0.75 if opt.out_direction is None else 0.83

        # convert the "seed:weight,..." string into a list of [seed, weight] pairs
        if opt.with_variations is not None:
            opt.with_variations = split_variations(opt.with_variations)

        if opt.prompt_as_dir and operation == 'generate':
            # sanitize the prompt to a valid folder name
            subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .')

            # truncate path to maximum allowed length
            # 39 is the length of '######.##########.##########-##.png', plus two separators and a NUL
            subdir = subdir[:(path_max - 39 - len(os.path.abspath(opt.outdir)))]
            current_outdir = os.path.join(opt.outdir, subdir)

            print('Writing files to directory: "' + current_outdir + '"')

            # make sure the output directory exists
            if not os.path.exists(current_outdir):
                os.makedirs(current_outdir)
        else:
            if not os.path.exists(opt.outdir):
                os.makedirs(opt.outdir)
            current_outdir = opt.outdir

        # Here is where the images are actually generated!
        last_results = []
        try:
            file_writer = PngWriter(current_outdir)
            results = [] # list of filename, prompt pairs
            grid_images = dict() # seed -> Image, only used if `opt.grid`
            prior_variations = opt.with_variations or []
            prefix = file_writer.unique_prefix()
            # intermediates are only saved when a positive interval was requested
            step_callback = make_step_callback(gen, opt, prefix) if opt.save_intermediates > 0 else None

            def image_writer(image, seed, upscaled=False, first_seed=None, use_prefix=None, prompt_in=None, attention_maps_image=None):
                # Callback invoked for each produced image: saves it (or stashes it
                # for the grid) and records the result for logging and seed reuse.
                # note the seed is the seed of the current image
                # the first_seed is the original seed that noise is added to
                # when the -v switch is used to generate variations
                # attention_maps_image is accepted but not used in this function
                nonlocal prior_variations
                nonlocal prefix

                path = None
                if opt.grid:
                    # grid mode: keep the image in memory; the grid is written after generation
                    grid_images[seed] = image

                elif operation == 'mask':
                    filename = f'{prefix}.{use_prefix}.{seed}.png'
                    tm = opt.text_mask[0]
                    th = opt.text_mask[1] if len(opt.text_mask)>1 else 0.5
                    formatted_dream_prompt = f'!mask {opt.input_file_path} -tm {tm} {th}'
                    path = file_writer.save_image_and_prompt_to_png(
                        image = image,
                        dream_prompt = formatted_dream_prompt,
                        metadata = {},
                        name = filename,
                        compress_level = opt.png_compression,
                    )
                    results.append([path, formatted_dream_prompt])

                else:
                    if use_prefix is not None:
                        prefix = use_prefix
                    postprocessed = upscaled if upscaled else operation=='postprocess'
                    opt.prompt = gen.concept_lib().replace_triggers_with_concepts(opt.prompt or prompt_in) # to avoid the problem of non-unique concept triggers
                    filename, formatted_dream_prompt = prepare_image_metadata(
                        opt,
                        prefix,
                        seed,
                        operation,
                        prior_variations,
                        postprocessed,
                        first_seed
                    )
                    path = file_writer.save_image_and_prompt_to_png(
                        image = image,
                        dream_prompt = formatted_dream_prompt,
                        metadata = metadata_dumps(
                            opt,
                            # without variations the image's own seed is recorded, otherwise the base seed
                            seeds = [seed if opt.variation_amount==0 and len(prior_variations)==0 else first_seed],
                            model_hash = gen.model_hash,
                        ),
                        name = filename,
                        compress_level = opt.png_compression,
                    )

                    # update rfc metadata
                    if operation == 'postprocess':
                        # opt.last_operation is expected to look like 'postprocess:<tool>'
                        tool = re.match('postprocess:(\w+)',opt.last_operation).groups()[0]
                        add_postprocessing_to_metadata(
                            opt,
                            opt.input_file_path,
                            filename,
                            tool,
                            formatted_dream_prompt,
                        )

                    if (not postprocessed) or opt.save_original:
                        # only append to results if we didn't overwrite an earlier output
                        results.append([path, formatted_dream_prompt])

                # so that the seed autocompletes (on linux|mac when -S or --seed specified
                if completer and operation == 'generate':
                    completer.add_seed(seed)
                    completer.add_seed(first_seed)
                last_results.append([path, seed])

            if operation == 'generate':
                catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
                opt.last_operation='generate'
                try:
                    gen.prompt2image(
                        image_callback=image_writer,
                        step_callback=step_callback,
                        catch_interrupts=catch_ctrl_c,
                        **vars(opt)
                    )
                except (PromptParser.ParsingException, pyparsing.ParseException) as e:
                    print('** An error occurred while processing your prompt **')
                    print(f'** {str(e)} **')
            elif operation == 'postprocess':
                print(f'>> fixing {opt.prompt}')
                opt.last_operation = do_postprocess(gen,opt,image_writer)

            elif operation == 'mask':
                print(f'>> generating masks from {opt.prompt}')
                do_textmask(gen, opt, image_writer)

            # assemble and save the grid from the images collected by image_writer()
            if opt.grid and len(grid_images) > 0:
                grid_img = make_grid(list(grid_images.values()))
                grid_seeds = list(grid_images.keys())
                first_seed = last_results[0][1]
                filename = f'{prefix}.{first_seed}.png'
                formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
                formatted_dream_prompt += f' # {grid_seeds}'
                metadata = metadata_dumps(
                    opt,
                    seeds = grid_seeds,
                    model_hash = gen.model_hash
                )
                path = file_writer.save_image_and_prompt_to_png(
                    image = grid_img,
                    dream_prompt = formatted_dream_prompt,
                    metadata = metadata,
                    name = filename
                )
                results = [[path, formatted_dream_prompt]]

        # assertion and OS errors abort this command only; the loop keeps running
        except AssertionError as e:
            print(e)
            continue

        except OSError as e:
            print(e)
            continue

        print('Outputs:')
        log_path = os.path.join(current_outdir, 'invoke_log')
        output_cntr = write_log(results, log_path ,('txt', 'md'), output_cntr)
        print()

    print('goodbye!')
# TO DO: remove repetitive code and the awkward command.replace() trope
# Just do a simple parse of the command!
def do_command(command:str, gen, opt:Args, completer) -> tuple:
    """
    Dispatch a command line that starts with '!'.

    Returns a (command, operation) tuple. For !dream, !fix and !mask the
    leading keyword is stripped from `command` and `operation` is set to
    'generate', 'postprocess' or 'mask' so that the caller continues parsing
    the remaining switches. Every other recognized command is executed here
    and `operation` is returned as None. An unrecognized command is replaced
    by '-h' so that the caller prints the help text.
    """
    global infile
    operation = 'generate' # default operation, alternative is 'postprocess'

    if command.startswith('!dream'): # in case a stored prompt still contains the !dream command
        command = command.replace('!dream ','',1)

    elif command.startswith('!fix'):
        command = command.replace('!fix ','',1)
        operation = 'postprocess'

    elif command.startswith('!mask'):
        command = command.replace('!mask ','',1)
        operation = 'mask'

    elif command.startswith('!switch'):
        model_name = command.replace('!switch ','',1)
        gen.set_model(model_name)
        add_embedding_terms(gen, completer)
        completer.add_history(command)
        operation = None

    elif command.startswith('!models'):
        gen.model_cache.print_models()
        completer.add_history(command)
        operation = None

    elif command.startswith('!import'):
        path = shlex.split(command)
        if len(path) < 2:
            print('** please provide a path to a .ckpt or .vae model file')
        elif not os.path.exists(path[1]):
            print(f'** {path[1]}: file not found')
        else:
            add_weights_to_config(path[1], gen, opt, completer)
        completer.add_history(command)
        operation = None

    elif command.startswith('!edit'):
        path = shlex.split(command)
        if len(path) < 2:
            print('** please provide the name of a model')
        else:
            edit_config(path[1], gen, opt, completer)
        completer.add_history(command)
        operation = None

    elif command.startswith('!del'):
        path = shlex.split(command)
        if len(path) < 2:
            print('** please provide the name of a model')
        else:
            del_config(path[1], gen, opt, completer)
        completer.add_history(command)
        operation = None

    elif command.startswith('!fetch'):
        file_path = command.replace('!fetch','',1).strip()
        retrieve_dream_command(opt,file_path,completer)
        completer.add_history(command)
        operation = None

    elif command.startswith('!replay'):
        file_path = command.replace('!replay','',1).strip()
        # a replay is only started when no other script is currently being read
        if infile is None and os.path.isfile(file_path):
            infile = open(file_path, 'r', encoding='utf-8')
        completer.add_history(command)
        operation = None

    elif command.startswith('!history'):
        completer.show_history()
        operation = None

    elif command.startswith('!search'):
        search_str = command.replace('!search','',1).strip()
        completer.show_history(search_str)
        operation = None

    elif command.startswith('!clear'):
        completer.clear_history()
        operation = None

    else:
        # "!NNN" recalls history line NNN into the edit buffer. The pattern is
        # a raw string (so \d is not an invalid str escape) and is matched once.
        history_ref = re.match(r'^!(\d+)', command)
        if history_ref:
            command = completer.get_line(int(history_ref.group(1)))
            completer.set_line(command)
            operation = None
        else: # not a recognized command, so give the --help text
            command = '-h'

    return command, operation
def set_default_output_dir(opt:Args, completer:Completer):
    '''
    Make opt.outdir an existing directory and tell the completer about it.
    A relative opt.outdir is joined onto Globals.root and normalized; the
    directory is created if it does not exist yet.
    '''
    if not os.path.isabs(opt.outdir):
        rooted = os.path.join(Globals.root, opt.outdir)
        opt.outdir = os.path.normpath(rooted)

    target = opt.outdir
    if not os.path.exists(target):
        os.makedirs(target)
    completer.set_default_dir(target)
def add_weights_to_config(model_path:str, gen, opt, completer):
    """
    Interactively collect the settings for a newly imported weights file.

    Prompts for a short model name, description, configuration file, optional
    VAE file and default width/height, then passes the result to
    write_config_file(). When that succeeds the model name is added to the
    completer.

    KeyboardInterrupt and EOFError raised by input() are allowed to propagate
    so that the user (or a closed stdin) can leave the prompts.
    """
    print('>> Model import in process. Please enter the values needed to configure this model:')
    print()

    new_config = {'weights': model_path}

    # keep asking until the name consists only of word characters, '.', '_' or '-'
    while True:
        model_name = input('Short name for this model: ')
        if re.match(r'^[\w._-]+$', model_name):
            break
        print('** model name must contain only words, digits and the characters [._-] **')

    new_config['description'] = input('Description of this model: ')

    # the configuration file must exist
    completer.complete_extensions(('.yaml','.yml'))
    completer.linebuffer = 'configs/stable-diffusion/v1-inference.yaml'
    while True:
        new_config['config'] = input('Configuration file for this model: ')
        if os.path.exists(new_config['config']):
            break

    # the VAE is optional: an empty answer skips it, anything else must exist
    completer.complete_extensions(('.vae.pt','.vae','.ckpt'))
    while True:
        vae = input('VAE autoencoder file for this model [None]: ')
        if os.path.exists(vae):
            new_config['vae'] = vae
            break
        if len(vae) == 0:
            break

    completer.complete_extensions(None)

    for field in ('width','height'):
        while True:
            completer.linebuffer = '512'
            # only a failed int() conversion is treated as bad input; the range
            # is checked explicitly so that it still applies under `python -O`
            try:
                value = int(input(f'Default image {field}: '))
            except ValueError:
                value = None
            if value is not None and 64 <= value <= 2048:
                new_config[field] = value
                break
            print('** Please enter a valid integer between 64 and 2048')

    make_default = input('Make this the default model? [n] ') in ('y','Y')

    if write_config_file(opt.conf, gen, model_name, new_config, make_default=make_default):
        completer.add_model(model_name)
def del_config(model_name:str, gen, opt, completer):
    """
    Remove `model_name` from the model cache, commit the change to opt.conf
    and drop the name from the completer. The currently active model is
    never deleted; a message is printed instead.
    """
    if gen.model_name == model_name:
        print("** Can't delete active model. !switch to another model first. **")
        return

    cache = gen.model_cache
    cache.del_model(model_name)
    cache.commit(opt.conf)
    print(f'** {model_name} deleted')
    completer.del_model(model_name)
def edit_config(model_name:str, gen, opt, completer):
    """
    Interactively edit the stored configuration of `model_name`.

    Each field is offered with its current value preloaded in the completer's
    line buffer. Width and height must be integers: an answer that int()
    cannot convert is reported and asked again instead of raising ValueError.
    The collected values are handed to write_config_file() with clobber=True.
    """
    config = gen.model_cache.config

    if model_name not in config:
        print(f'** Unknown model {model_name}')
        return

    print(f'\n>> Editing model {model_name} from configuration file {opt.conf}')

    conf = config[model_name]
    new_config = {}
    completer.complete_extensions(('.yaml','.yml','.ckpt','.vae.pt'))
    try:
        for field in ('description', 'weights', 'vae', 'config', 'width','height'):
            while True:
                completer.linebuffer = str(conf[field]) if field in conf else ''
                new_value = input(f'{field}: ')
                if field not in ('width','height'):
                    new_config[field] = new_value
                    break
                try:
                    new_config[field] = int(new_value)
                    break
                except ValueError:
                    print(f'** {field} must be an integer')
        make_default = input('Make this the default model? [n] ') in ('y','Y')
    finally:
        # always restore unrestricted completion, even when a prompt is interrupted
        completer.complete_extensions(None)
    write_config_file(opt.conf, gen, model_name, new_config, clobber=True, make_default=make_default)
def write_config_file(conf_path, gen, model_name, new_config, clobber=False, make_default=False):
    """
    Show `new_config`, ask for confirmation, verify that the model loads and
    commit it to `conf_path`.

    clobber      -- passed to model_cache.add_model(); also selects the verb
                    ('modify' vs 'import') used in the confirmation prompt
    make_default -- mark the model as default in the written configuration

    Returns True when the configuration was committed, False when the user
    declined or the model failed to load. On a failed load the model is
    removed from the model cache again.
    """
    current_model = gen.model_name

    op = 'modify' if clobber else 'import'
    print('\n>> New configuration:')
    if make_default:
        new_config['default'] = True
    print(yaml.dump({model_name:new_config}))
    if input(f'OK to {op} [n]? ') not in ('y','Y'):
        return False

    # The load result is checked explicitly rather than with `assert`, which is
    # removed under `python -O`. AssertionError is still caught here, as before,
    # in case add_model()/set_model() raise it themselves.
    try:
        print('>> Verifying that new model loads...')
        gen.model_cache.add_model(model_name, new_config, clobber)
        loaded = gen.set_model(model_name) is not None
    except AssertionError:
        loaded = False

    if not loaded:
        print('** aborting **')
        gen.model_cache.del_model(model_name)
        return False

    if make_default:
        print('making this default')
        gen.model_cache.set_default_model(model_name)

    gen.model_cache.commit(conf_path)

    # an empty answer or one starting with y/Y keeps the new model loaded;
    # anything else switches back to the previously active model
    do_switch = input('Keep model loaded? [y]')
    if do_switch and do_switch[0] not in ('y','Y'):
        gen.set_model(current_model)
    return True
def do_textmask(gen, opt, callback):
    """
    Generate a text-prompted mask for the image named by opt.prompt.

    The image is looked up as given and then relative to opt.outdir.
    opt.text_mask supplies the mask prompt and, optionally, a threshold
    (0.5 when omitted). The resolved path is stored in opt.input_file_path
    and the work is delegated to gen.apply_textmask().

    Raises AssertionError when the image cannot be found or no text mask was
    given. The exception is raised explicitly, not via `assert`, so that the
    checks are still performed under `python -O`.
    """
    image_path = opt.prompt
    if not os.path.exists(image_path):
        image_path = os.path.join(opt.outdir,image_path)
    if not os.path.exists(image_path):
        raise AssertionError(f'** "{opt.prompt}" not found. Please enter the name of an existing image file to mask **')
    if opt.text_mask is None or len(opt.text_mask) < 1:
        raise AssertionError('** Please provide a text mask with -tm **')
    opt.input_file_path = image_path
    tm = opt.text_mask[0]
    threshold = float(opt.text_mask[1]) if len(opt.text_mask) > 1 else 0.5
    gen.apply_textmask(
        image_path = image_path,
        prompt = tm,
        threshold = threshold,
        callback = callback,
    )
def do_postprocess (gen, opt, callback):
    """
    Run a postprocessing tool on the image named by opt.prompt.

    opt.prompt is replaced by opt.new_prompt (which may be None), a bare file
    name is resolved against opt.outdir, and the tool is picked from the
    switches in this order: face tool, embiggen, upscale, outpaint, outcrop.
    Returns opt.last_operation ('postprocess:<tool>') on success and None when
    gen.apply_postprocessor() raised OSError, KeyError or AttributeError.
    """
    file_path = opt.prompt # treat the prompt as the file pathname
    opt.prompt = opt.new_prompt # stays None when no replacement prompt was given

    if os.path.dirname(file_path) == '': # basename given
        file_path = os.path.join(opt.outdir,file_path)
    opt.input_file_path = file_path

    tool = None
    if opt.facetool_strength > 0:
        tool = opt.facetool
    else:
        # first switch that is set wins, checked in this order
        switch_to_tool = (
            ('embiggen', 'embiggen'),
            ('upscale', 'upscale'),
            ('out_direction', 'outpaint'),
            ('outcrop', 'outcrop'),
        )
        for switch, label in switch_to_tool:
            if getattr(opt, switch):
                tool = label
                break

    opt.save_original = True # do not overwrite old image!
    opt.last_operation = f'postprocess:{tool}'

    try:
        gen.apply_postprocessor(
            image_path = file_path,
            tool = tool,
            facetool_strength = opt.facetool_strength,
            codeformer_fidelity = opt.codeformer_fidelity,
            save_original = opt.save_original,
            upscale = opt.upscale,
            out_direction = opt.out_direction,
            outcrop = opt.outcrop,
            callback = callback,
            opt = opt,
        )
    except (OSError, KeyError, AttributeError) as failure:
        print(traceback.format_exc(), file=sys.stderr)
        if isinstance(failure, OSError):
            print(f'** {file_path}: file could not be read')
        return None
    else:
        return opt.last_operation
def add_postprocessing_to_metadata(opt,original_file,new_file,tool,command):
    """
    Append a {'tool', 'dream_command'} record to the 'postprocessing' list in
    the image metadata and write the metadata into `new_file`.

    Metadata is read from `original_file`, falling back to `new_file` and then
    to an empty dict when reading raises AttributeError. Both file names are
    resolved against opt.outdir when they do not exist as given.
    """
    if not os.path.exists(original_file):
        original_file = os.path.join(opt.outdir,original_file)
    if not os.path.exists(new_file):
        new_file = os.path.join(opt.outdir,new_file)

    # first readable source wins; stays {} when neither file yields metadata
    meta = {}
    for source in (original_file, new_file):
        try:
            meta = retrieve_metadata(source)['sd-metadata']
        except AttributeError:
            continue
        break

    if 'image' not in meta:
        meta = metadata_dumps(opt,seeds=[opt.seed])['image']
        meta['image'] = {}

    img_data = meta.get('image')
    history = img_data.get('postprocessing',[]) or []
    record = {
        'tool':tool,
        'dream_command':command,
    }
    history.append(record)
    meta['image']['postprocessing'] = history
    write_metadata(new_file,meta)
def prepare_image_metadata(
        opt,
        prefix,
        seed,
        operation='generate',
        prior_variations=None,
        postprocessed=False,
        first_seed=None
):
    """
    Build the output file name and the reproducing command string for an image.

    opt              -- options object; reads fnformat, variation_amount,
                        save_original and input_file_path, and sets
                        with_variations when a variation amount is active
    prefix, seed     -- available to opt.fnformat as {prefix} and {seed},
                        together with every attribute in opt.__dict__
    operation        -- 'postprocess' yields a '!fix ...' command string
    prior_variations -- list of [seed, weight] pairs already applied
                        (None, the default, means none)
    postprocessed    -- with opt.save_original, selects a postprocess file name
    first_seed       -- base seed used in the command when variations apply

    Returns a (filename, formatted_dream_prompt) tuple.
    """
    # None instead of a mutable [] default, so no list is shared between calls
    if prior_variations is None:
        prior_variations = []

    if postprocessed and opt.save_original:
        filename = choose_postprocess_name(opt,prefix,seed)
    else:
        wildcards = dict(opt.__dict__)
        wildcards['prefix'] = prefix
        wildcards['seed'] = seed
        try:
            filename = opt.fnformat.format(**wildcards)
        except KeyError as e:
            print(f'** The filename format contains an unknown key \'{e.args[0]}\'. Will use \'{{prefix}}.{{seed}}.png\' instead')
            filename = f'{prefix}.{seed}.png'
        except IndexError:
            # raised by positional fields such as '{0}' or '{}'
            print('** The filename format is broken or incomplete. Will use \'{prefix}.{seed}.png\' instead')
            filename = f'{prefix}.{seed}.png'

    if opt.variation_amount > 0:
        first_seed = first_seed or seed
        this_variation = [[seed, opt.variation_amount]]
        opt.with_variations = prior_variations + this_variation
        formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
    elif len(prior_variations) > 0:
        formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
    elif operation == 'postprocess':
        formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed,prompt=opt.input_file_path)
    else:
        formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
    return filename,formatted_dream_prompt
def choose_postprocess_name(opt,prefix,seed) -> str:
    """
    Return a file name for a postprocessed image that does not yet exist in
    opt.outdir.

    The name is '<prefix>.<seed>.<modifier>.png', where the modifier is the
    tool name taken from opt.last_operation ('postprocess:<tool>') or
    'postprocessed' when last_operation has another form or is None. If the
    name is taken, '-01', '-02', ... is appended to the modifier until a free
    name is found.
    """
    # raw pattern; `or ''` keeps a None last_operation from raising TypeError
    match = re.search(r'postprocess:(\w+)', opt.last_operation or '')
    # will look like "gfpgan", "upscale", "outpaint" or "embiggen"
    modifier = match.group(1) if match else 'postprocessed'

    counter = 0
    while True:
        suffix = '' if counter == 0 else f'-{counter:02d}'
        filename = f'{prefix}.{seed}.{modifier}{suffix}.png'
        if not os.path.exists(os.path.join(opt.outdir,filename)):
            return filename
        counter += 1
def get_next_command(infile=None) -> str: # command string
    """
    Return the next command: typed at the 'invoke> ' prompt when `infile` is
    None, otherwise the next line of `infile` with surrounding whitespace
    removed. Non-empty lines read from a file are echoed with a leading '#'.
    Raises EOFError once the file is exhausted.
    """
    if infile is None:
        return input('invoke> ')

    line = infile.readline()
    if not line:
        # readline() returns '' only at end of file; blank lines still carry '\n'
        raise EOFError
    line = line.strip()
    if line:
        print(f'#{line}')
    return line
def invoke_ai_web_server_loop(gen: Generate, gfpgan, codeformer, esrgan):
    """
    Start the InvokeAI web server with the given generator and restoration
    modules and run it until it returns or is interrupted with Ctrl-C.
    """
    print('\n* --web was specified, starting web server...')
    # imported lazily, so the backend package is only needed in web mode
    from backend.invoke_ai_web_server import InvokeAIWebServer
    # Change working directory to the parent of the directory containing this file
    os.chdir(
        os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    )

    invoke_ai_web_server = InvokeAIWebServer(generate=gen, gfpgan=gfpgan, codeformer=codeformer, esrgan=esrgan)

    try:
        invoke_ai_web_server.run()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; return quietly
        pass
def add_embedding_terms(gen,completer):
    '''
    Refresh the autocompleter with the terms currently known to the model's
    embedding manager. Meant to be called after the model has been set.
    '''
    manager = gen.model.embedding_manager
    terms = manager.list_terms()
    completer.add_embedding_terms(terms)
def split_variations(variations_string) -> list:
    """
    Parse a 'seed:weight,seed:weight,...' string into a list of
    [int seed, float weight] pairs. Returns None, after printing a message,
    as soon as one part cannot be parsed.
    """
    pairs = []
    for part in variations_string.split(','):
        fields = part.split(':')
        try:
            # a wrong field count is reported exactly like an unparsable number
            if len(fields) != 2:
                raise ValueError(part)
            pair = [int(fields[0]), float(fields[1])]
        except ValueError:
            print(f'** Could not parse with_variation part "{part}"')
            return None
        pairs.append(pair)
    return pairs or None
def load_face_restoration(opt):
    """
    Load the face-restoration and upscaling models requested in `opt`.

    Returns a (gfpgan, codeformer, esrgan) tuple; an entry is None when the
    corresponding feature is disabled or its module could not be imported.
    """
    gfpgan = codeformer = esrgan = None

    if not (opt.restore or opt.esrgan):
        print('>> Face restoration and upscaling disabled')
        return gfpgan,codeformer,esrgan

    try:
        # imported lazily: only needed when at least one feature is enabled
        from ldm.invoke.restoration import Restoration
        restoration = Restoration()

        if not opt.restore:
            print('>> Face restoration disabled')
        else:
            gfpgan, codeformer = restoration.load_face_restore_models(opt.gfpgan_model_path)

        if not opt.esrgan:
            print('>> Upscaling disabled')
        else:
            esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
    except (ModuleNotFoundError, ImportError):
        print(traceback.format_exc(), file=sys.stderr)
        print('>> You may need to install the ESRGAN and/or GFPGAN modules')
    return gfpgan,codeformer,esrgan
def make_step_callback(gen, opt, prefix):
    """
    Create <opt.outdir>/intermediates/<prefix> and return a callback(img, step)
    that saves the sample there as a PNG on every opt.save_intermediates-th
    step and on step opt.steps-1.
    """
    destination = os.path.join(opt.outdir,'intermediates',prefix)
    os.makedirs(destination,exist_ok=True)
    print(f'>> Intermediate images will be written into {destination}')

    def callback(img, step):
        on_interval = step % opt.save_intermediates == 0
        if not on_interval and step != opt.steps-1:
            return
        target = os.path.join(destination,f'{step:04}.png')
        gen.sample_to_image(img).save(target,'PNG')

    return callback
def retrieve_dream_command(opt,command,completer):
    '''
    Given a full or partial path to a previously-generated image file,
    will retrieve and format the dream command used to generate the image,
    and pop it into the readline buffer (linux, Mac), or print out a comment
    for cut-and-paste (windows)

    Given a wildcard path to a folder with image png files,
    will retrieve and format the dream command used to generate the images,
    and save them to a file commands.txt for further processing

    `command` is split on whitespace: the first token is the image path (a
    bare file name is looked up in opt.outdir); an optional second token is
    the output file, in which case the work is delegated to write_commands().
    An empty or whitespace-only `command` is ignored.
    '''
    tokens = command.split()
    # also covers a whitespace-only command, which would otherwise fail on tokens[0]
    if not tokens:
        return

    folder,basename = os.path.split(tokens[0])
    if len(folder) == 0:
        path = os.path.join(opt.outdir,basename)
    else:
        path = tokens[0]

    if len(tokens) > 1:
        return write_commands(opt, path, tokens[1])

    cmd = ''
    try:
        cmd = dream_cmd_from_png(path)
    except OSError:
        print(f'## {tokens[0]}: file could not be read')
    except (KeyError, AttributeError, IndexError):
        print(f'## {tokens[0]}: file has no metadata')
    except Exception:
        # deliberately broad best-effort handler, but no longer a bare `except:`
        # so KeyboardInterrupt and SystemExit are not swallowed
        print(f'## {tokens[0]}: file could not be processed')
    if cmd:
        completer.set_line(cmd)
def write_commands(opt, file_path:str, outfilepath:str):
    """
    Collect the dream commands of all images matching the glob pattern
    `file_path` and write them to `outfilepath`.

    Each command is preceded by a '# <path>' comment line. Files whose command
    cannot be retrieved are reported and skipped. A bare output file name is
    placed in opt.outdir. Nothing is written when no command was collected.
    """
    folder,pattern = os.path.split(file_path)
    try:
        paths = sorted(Path(folder).glob(pattern))
    except ValueError:
        print(f'## "{pattern}": unacceptable pattern')
        return

    commands = []
    for path in paths:
        # reset for every file, so a failure cannot re-emit the previous file's command
        cmd = None
        try:
            cmd = dream_cmd_from_png(path)
        except (KeyError, AttributeError, IndexError):
            print(f'## {path}: file has no metadata')
        except Exception:
            # broad best-effort handler, but no longer a bare `except:` so
            # KeyboardInterrupt and SystemExit are not swallowed
            print(f'## {path}: file could not be processed')
        if cmd:
            commands.append(f'# {path}')
            commands.append(cmd)

    if commands:
        out_folder,out_name = os.path.split(outfilepath)
        if len(out_folder) == 0:
            outfilepath = os.path.join(opt.outdir,out_name)
        with open(outfilepath, 'w', encoding='utf-8') as f:
            f.write('\n'.join(commands))
        print(f'>> File {outfilepath} with commands created')
def emergency_model_reconfigure():
    """
    Tell the user that the model files look missing or misconfigured, then
    launch configure_invokeai in interactive mode.
    """
    rule = '!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
    notice = (
        rule,
        ' You appear to have a missing or misconfigured model file(s). ',
        ' The script will now exit and run configure_invokeai.py to help fix the problem.',
        ' After reconfiguration is done, please relaunch invoke.py. ',
        rule,
    )

    print()
    for line in notice:
        print(line)
    print('configure_invokeai is launching....\n')

    # the configurator reads sys.argv, so present it with its own command line
    sys.argv = ['configure_invokeai','--interactive']
    import configure_invokeai
    configure_invokeai.main()