InvokeAI/scripts/dream.py

454 lines
16 KiB
Python
Raw Normal View History

2022-08-21 23:57:48 +00:00
#!/usr/bin/env python3
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import os
import re
import sys
2022-09-04 20:48:17 +00:00
import shlex
2022-08-24 16:02:36 +00:00
import copy
import warnings
import time
import ldm.dream.readline
from ldm.dream.args import Args, metadata_dumps, metadata_from_png
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import DreamServer, ThreadingDreamServer
2022-08-31 04:36:38 +00:00
from ldm.dream.image_util import make_grid
2022-09-16 13:03:41 +00:00
from ldm.dream.log import write_log
from omegaconf import OmegaConf
2022-08-18 16:43:59 +00:00
# Placeholder to be replaced with proper class that tracks the
# outputs and associates with the prompt that generated them.
# Just want to get the formatting look right for now.
#
# Running count of logged outputs. main_loop() passes it to write_log() and
# stores the returned value back here; write_log_message() increments it once
# per logged line.
output_cntr = 0
def main():
    """Initialize command-line parsers and the diffusion model.

    Parses the command line, rejects the deprecated --laion400m and --weights
    options, tries to load the face-restoration and upscaling helpers, builds
    the Generate object, opens the optional command file and finally hands
    control to either the web server or the prompt/read/execute loop.

    The process exits with status -1 when argument parsing yields nothing, a
    deprecated option is used, the generator cannot be constructed or the
    command file cannot be opened, and with status 0 once the web server
    loop returns.
    """
    opt = Args()
    args = opt.parse_args()
    if not args:
        sys.exit(-1)

    if args.laion400m:
        print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
        sys.exit(-1)
    if args.weights:
        print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
        sys.exit(-1)

    print('* Initializing, be patient...\n')
    sys.path.append('.')
    from ldm.generate import Generate

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers
    transformers.logging.set_verbosity_error()

    # Loading Face Restoration and ESRGAN Modules
    try:
        # assigned first so the names exist even if the import below fails
        gfpgan, codeformer, esrgan = None, None, None
        from ldm.dream.restoration.base import Restoration
        restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
        if opt.restore:
            gfpgan, codeformer = restoration.load_face_restore_models()
        else:
            print('>> Face restoration disabled')
        if opt.esrgan:
            esrgan = restoration.load_esrgan()
        else:
            print('>> Upscaling disabled')
    except (ModuleNotFoundError, ImportError):
        # best effort: report the problem and carry on without these tools
        import traceback
        print(traceback.format_exc(), file=sys.stderr)
        print('>> You may need to install the ESRGAN and/or GFPGAN modules')

    # creating a simple text2image object with a handful of
    # defaults passed on the command line.
    # additional parameters will be added (or overridden) during
    # the user input loop
    try:
        gen = Generate(
            conf=opt.conf,
            model=opt.model,
            sampler_name=opt.sampler_name,
            embedding_path=opt.embedding_path,
            full_precision=opt.full_precision,
            precision=opt.precision,
            gfpgan=gfpgan,
            codeformer=codeformer,
            esrgan=esrgan,
        )
    except (FileNotFoundError, IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    # make sure the output directory exists
    if not os.path.exists(opt.outdir):
        os.makedirs(opt.outdir)

    # open the command file (or use stdin for '-'); commands are read from
    # it one line at a time by the main loop
    infile = None
    if opt.infile:
        try:
            if os.path.isfile(opt.infile):
                infile = open(opt.infile, 'r', encoding='utf-8')
            elif opt.infile == '-':  # stdin
                infile = sys.stdin
            else:
                raise FileNotFoundError(f'{opt.infile} not found.')
        except (FileNotFoundError, IOError) as e:
            print(f'{e}. Aborting.')
            sys.exit(-1)

    try:
        if opt.seamless:
            print(">> changed to seamless tiling mode")

        # preload the model
        gen.load_model()

        if not infile:
            print(
                "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
            )

        # web server loops forever
        if opt.web:
            dream_server_loop(gen, opt.host, opt.port, opt.outdir, gfpgan)
            sys.exit(0)

        main_loop(gen, opt, infile)
    finally:
        # release the command file we opened ourselves; never close stdin
        if infile is not None and infile is not sys.stdin:
            infile.close()
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, opt, infile):
    """prompt/read/execute loop

    Fetches one command at a time with get_next_command() (from `infile`, or
    from the interactive prompt when `infile` is None), parses it into `opt`
    and then either generates images or, for '!fix' commands, post-processes
    an existing image. The loop ends on EOF or on a lone 'q', after which
    'goodbye!' is printed.
    """
    global output_cntr

    illegal_chars = re.compile(r'[<>:"/\\|?*]')
    last_results = []
    model_config = OmegaConf.load(opt.conf)[opt.model]

    # os.pathconf is not available on Windows
    if hasattr(os, 'pathconf'):
        path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX')
        name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX')
    else:
        path_max, name_max = 260, 255

    while True:
        # default operation, alternative is 'postprocess'
        operation = 'generate'

        try:
            command = get_next_command(infile)
        except EOFError:
            break

        stripped = command.strip()
        # skip empty lines and comment lines
        if not stripped or command.startswith(('#', '//')):
            continue
        if len(stripped) == 1 and command.startswith('q'):
            break

        # in case a stored prompt still contains the !dream command
        if command.startswith('!dream'):
            command = command.replace('!dream ', '', 1)
        if command.startswith('!fix'):
            command = command.replace('!fix ', '', 1)
            operation = 'postprocess'

        if opt.parse_cmd(command) is None:
            continue

        # when an init image is given, try to take the prompt from its metadata
        if opt.init_img:
            try:
                previous_args = metadata_from_png(opt.init_img)
                opt.prompt = previous_args.prompt
                print(f'>> Retrieved old prompt "{opt.prompt}" from {opt.init_img}')
            except (AttributeError, KeyError):
                pass

        if len(opt.prompt) == 0:
            print('\nTry again with a prompt!')
            continue

        # width and height are set by model if not specified
        if not opt.width:
            opt.width = model_config.width
        if not opt.height:
            opt.height = model_config.height

        # retrieve previous value of init image if requested
        if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
            try:
                opt.init_img = last_results[int(opt.init_img)][0]
                print(f'>> Reusing previous image {opt.init_img}')
            except IndexError:
                print(
                    f'>> No previous initial image at position {opt.init_img} found')
                opt.init_img = None
                continue

        # retrieve previous value of seed if requested
        if opt.seed is not None and opt.seed < 0:
            try:
                opt.seed = last_results[opt.seed][1]
                print(f'>> Reusing previous seed {opt.seed}')
            except IndexError:
                print(f'>> No previous seed at position {opt.seed} found')
                opt.seed = None
                continue

        if opt.strength is None:
            opt.strength = 0.75 if opt.out_direction is None else 0.83

        if opt.with_variations is not None:
            # turn "seed:weight,seed:weight" into [[seed, weight], ...]
            pairs = []
            parse_failed = False
            for part in opt.with_variations.split(','):
                try:
                    # a wrong number of fields raises ValueError on
                    # unpacking, exactly like a bad int/float conversion
                    seed_text, weight_text = part.split(':')
                    pairs.append([int(seed_text), float(weight_text)])
                except ValueError:
                    print(f'could not parse with_variation part "{part}"')
                    parse_failed = True
                    break
            if parse_failed:
                continue
            opt.with_variations = pairs if len(pairs) > 0 else None

        if opt.prompt_as_dir:
            # sanitize the prompt to a valid folder name
            subdir = illegal_chars.sub('_', opt.prompt)[:name_max].rstrip(' .')

            # truncate path to maximum allowed length
            # 27 is the length of '######.##########.##.png', plus two separators and a NUL
            room = path_max - 27 - len(os.path.abspath(opt.outdir))
            subdir = subdir[:room]
            current_outdir = os.path.join(opt.outdir, subdir)

            print('Writing files to directory: "' + current_outdir + '"')

            # make sure the output directory exists
            if not os.path.exists(current_outdir):
                os.makedirs(current_outdir)
        else:
            if not os.path.exists(opt.outdir):
                os.makedirs(opt.outdir)
            current_outdir = opt.outdir

        # Here is where the images are actually generated!
        last_results = []
        try:
            file_writer = PngWriter(current_outdir)
            prefix = file_writer.unique_prefix()
            results = []      # list of [path, prompt] pairs
            grid_images = {}  # seed -> Image, only used if `opt.grid`
            prior_variations = opt.with_variations or []

            def image_writer(image, seed, upscaled=False, first_seed=None):
                """Callback that stores one finished image.

                `seed` is the seed of the current image; `first_seed` is the
                original seed that noise is added to when the -v switch is
                used to generate variations.
                """
                path = None
                if opt.grid:
                    # grid tiles are only collected here and saved later
                    grid_images[seed] = image
                else:
                    if operation == 'postprocess':
                        filename = choose_postprocess_name(opt.prompt)
                    elif upscaled and opt.save_original:
                        filename = f'{prefix}.{seed}.postprocessed.png'
                    else:
                        filename = f'{prefix}.{seed}.png'

                    if opt.variation_amount > 0:
                        first_seed = first_seed or seed
                        opt.with_variations = prior_variations + [
                            [seed, opt.variation_amount]
                        ]
                        formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
                    elif len(prior_variations) > 0:
                        formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed)
                    elif operation == 'postprocess':
                        formatted_dream_prompt = '!fix '+opt.dream_prompt_str(seed=seed)
                    else:
                        formatted_dream_prompt = opt.dream_prompt_str(seed=seed)

                    image_metadata = metadata_dumps(
                        opt,
                        seeds=[seed],
                        model_hash=gen.model_hash,
                    )
                    path = file_writer.save_image_and_prompt_to_png(
                        image=image,
                        dream_prompt=formatted_dream_prompt,
                        metadata=image_metadata,
                        name=filename,
                    )
                    if (not upscaled) or opt.save_original:
                        # only append to results if we didn't overwrite an earlier output
                        results.append([path, formatted_dream_prompt])
                last_results.append([path, seed])

            if operation == 'generate':
                # if running interactively, we catch keyboard interrupts
                catch_ctrl_c = infile is None
                gen.prompt2image(
                    image_callback=image_writer,
                    catch_interrupts=catch_ctrl_c,
                    **vars(opt)
                )
            elif operation == 'postprocess':
                print(f'>> fixing {opt.prompt}')
                do_postprocess(gen, opt, image_writer)

            if opt.grid and len(grid_images) > 0:
                grid_img = make_grid(list(grid_images.values()))
                grid_seeds = list(grid_images.keys())
                first_seed = last_results[0][1]
                filename = f'{prefix}.{first_seed}.png'
                formatted_dream_prompt = opt.dream_prompt_str(
                    seed=first_seed, grid=True, iterations=len(grid_images)
                )
                formatted_dream_prompt += f' # {grid_seeds}'
                metadata = metadata_dumps(
                    opt,
                    seeds=grid_seeds,
                    model_hash=gen.model_hash,
                )
                path = file_writer.save_image_and_prompt_to_png(
                    image=grid_img,
                    dream_prompt=formatted_dream_prompt,
                    metadata=metadata,
                    name=filename,
                )
                results = [[path, formatted_dream_prompt]]
        except (AssertionError, OSError) as e:
            # report the problem and go back to reading the next command
            print(e)
            continue

        print('Outputs:')
        log_path = os.path.join(current_outdir, 'dream_log')
        output_cntr = write_log(results, log_path, ('txt', 'md'), output_cntr)
        print()

    print('goodbye!')
def do_postprocess(gen, opt, callback):
    """Apply a post-processing tool to the image file named by opt.prompt.

    The prompt is treated as a file path; a bare file name (no directory
    part) is looked up inside opt.outdir. The tool is chosen from the first
    option that is set, in this order: face restoration (opt.facetool, when
    opt.gfpgan_strength > 0), embiggen, upscale, outpaint.

    Side effect: sets opt.save_original to True so the source image is not
    overwritten.

    Returns whatever gen.apply_postprocessor() returns, or None (after
    printing a message) when the file does not exist.
    """
    file_path = opt.prompt  # treat the prompt as the file pathname
    if os.path.dirname(file_path) == '':  # bare file name given
        file_path = os.path.join(opt.outdir, file_path)
    if not os.path.exists(file_path):
        print(f'* file {file_path} does not exist')
        return None

    tool = None
    if opt.gfpgan_strength > 0:
        tool = opt.facetool
    elif opt.embiggen:
        tool = 'embiggen'
    elif opt.upscale:
        tool = 'upscale'
    elif opt.out_direction:
        tool = 'outpaint'

    opt.save_original = True  # do not overwrite old image!
    return gen.apply_postprocessor(
        # hand over the resolved path that was just checked, not the raw
        # prompt, so a bare file name is found inside opt.outdir
        image_path=file_path,
        tool=tool,
        gfpgan_strength=opt.gfpgan_strength,
        codeformer_fidelity=opt.codeformer_fidelity,
        save_original=opt.save_original,
        upscale=opt.upscale,
        out_direction=opt.out_direction,
        callback=callback,
        opt=opt,
    )
def choose_postprocess_name(original_filename):
    """Return the output file name for a post-processed copy of an image.

    Only the base name of `original_filename` is used; directory and
    extension are dropped.

    * A name that is not yet a fixed copy gets '.fixed' appended:
      '000001.123.png' -> '000001.123.fixed.png'.
    * A name of the form '<digits>.<digits>.fixed' or
      '<digits>.<digits>.fixed-NN' gets the next two-digit counter:
      '000001.123.fixed.png' -> '000001.123.fixed-01.png' ->
      '000001.123.fixed-02.png'.

    The generated names keep the '.fixed' marker so that they are recognised
    again on the next call and the counter keeps increasing instead of
    producing the same name twice.
    """
    basename, _ = os.path.splitext(os.path.basename(original_filename))
    # raw string: '\d' in a plain string literal is an invalid escape sequence
    match = re.search(r'(\d+\.\d+)\.fixed(-(\d+))?$', basename)
    if match:
        counter = int(match.group(3) or 0) + 1
        return f'{match.group(1)}.fixed-{counter:02d}.png'
    return f'{basename}.fixed.png'
def get_next_command(infile=None) -> str:  # command string
    """Return the next command to execute.

    With no `infile` the user is prompted with 'dream> ' and the typed line
    is returned as is. Otherwise one line is read from `infile`, stripped of
    surrounding whitespace, echoed with a leading '#' when non-empty, and
    returned.

    Raises EOFError when `infile` has no more lines.
    """
    if infile is None:
        return input('dream> ')

    line = infile.readline()
    if not line:
        # readline() yields '' only at end of file
        raise EOFError

    command = line.strip()
    if command:
        print(f'#{command}')
    return command
2022-08-29 02:57:13 +00:00
def dream_server_loop(gen, host, port, outdir, gfpgan):
    """Configure DreamServer and serve web requests until interrupted.

    Changes the working directory to the parent of the directory holding
    this script, stores `gen`, `outdir` and the GFPGAN availability flag on
    the DreamServer class, then runs a ThreadingDreamServer bound to
    (host, port). A KeyboardInterrupt stops serving; the server is closed
    afterwards.
    """
    print('\n* --web was specified, starting web server...')

    # Change working directory to the stable-diffusion directory
    script_dir = os.path.dirname(__file__)
    os.chdir(os.path.abspath(os.path.join(script_dir, '..')))

    # Start server
    DreamServer.model = gen  # misnomer in DreamServer - this is not the model you are looking for
    DreamServer.outdir = outdir
    DreamServer.gfpgan_model_exists = (
        gfpgan.gfpgan_model_exists if gfpgan is not None else False
    )

    server = ThreadingDreamServer((host, port))
    print(">> Started Stable Diffusion dream server!")
    if host != '0.0.0.0':
        print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
        print(f">> Point your browser at http://{host}:{port}.")
    else:
        print(
            f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server
        pass

    server.server_close()
def write_log_message(results, log_path):
    """logs the name of the output image, prompt, and prompt args to the terminal and log file

    `results` is an iterable of (path, prompt) pairs. Each pair is printed
    with a running number taken from the module-level `output_cntr`, which is
    incremented once per pair, and the unnumbered lines are appended to the
    file at `log_path`.
    """
    global output_cntr

    entries = [f'{path}: {prompt}\n' for path, prompt in results]

    for entry in entries:
        output_cntr += 1
        print(f'[{output_cntr}] {entry}', end='')

    with open(log_path, 'a', encoding='utf-8') as log_file:
        log_file.write(''.join(entries))
2022-08-22 02:48:40 +00:00
# Start the command-line front end only when this file is run as a script,
# not when it is imported.
if __name__ == '__main__':
    main()