#!/usr/bin/env python3
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)

import argparse
import shlex
import os
import re
import sys
import copy
import warnings
import time
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf

# Placeholder to be replaced with proper class that tracks the
# outputs and associates with the prompt that generated them.
# Just want to get the formatting look right for now.
output_cntr = 0


def main():
    """Initialize command-line parsers and the diffusion model.

    Parses the launch arguments, looks up the selected model in the
    models configuration file, builds and preloads the generator, and
    then hands control to either the web server loop (``--web``) or the
    interactive/file-driven command loop.  Exits with status -1 on
    deprecated flags, an unreadable configuration, or a missing infile.
    """
    arg_parser = create_argv_parser()
    opt = arg_parser.parse_args()

    if opt.laion400m:
        print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
        sys.exit(-1)
    if opt.weights != 'model':
        print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
        sys.exit(-1)

    try:
        models = OmegaConf.load(opt.config)
        width = models[opt.model].width
        height = models[opt.model].height
        config = models[opt.model].config
        weights = models[opt.model].weights
    except (FileNotFoundError, IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    print('* Initializing, be patient...\n')
    sys.path.append('.')
    from pytorch_lightning import logging
    from ldm.generate import Generate

    # these two lines prevent a horrible warning message from appearing
    # when the frozen CLIP tokenizer is imported
    import transformers

    transformers.logging.set_verbosity_error()

    # creating a simple text2image object with a handful of
    # defaults passed on the command line.
    # additional parameters will be added (or overridden) during
    # the user input loop
    t2i = Generate(
        width=width,
        height=height,
        sampler_name=opt.sampler_name,
        weights=weights,
        full_precision=opt.full_precision,
        config=config,
        grid=opt.grid,  # this is solely for recreating the prompt
        seamless=opt.seamless,
        embedding_path=opt.embedding_path,
        device_type=opt.device,
        ignore_ctrl_c=opt.infile is None,
    )

    # make sure the output directory exists
    if not os.path.exists(opt.outdir):
        os.makedirs(opt.outdir)

    # gets rid of annoying messages about random seed
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # open the source of commands: a regular file, or stdin for '-'
    infile = None
    if opt.infile:
        try:
            if os.path.isfile(opt.infile):
                infile = open(opt.infile, 'r', encoding='utf-8')
            elif opt.infile == '-':  # stdin
                infile = sys.stdin
            else:
                raise FileNotFoundError(f'{opt.infile} not found.')
        except (FileNotFoundError, IOError) as e:
            print(f'{e}. Aborting.')
            sys.exit(-1)

    try:
        if opt.seamless:
            print(">> changed to seamless tiling mode")

        # preload the model
        t2i.load_model()

        if not infile:
            print(
                "\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
            )

        cmd_parser = create_cmd_parser()
        if opt.web:
            dream_server_loop(t2i, opt.host, opt.port, opt.outdir)
        else:
            main_loop(t2i, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
    finally:
        # release the prompt file we opened ourselves; never close stdin
        if infile is not None and infile is not sys.stdin:
            infile.close()


def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
    """prompt/read/execute loop

    Reads commands one at a time (interactively, or from ``infile``),
    parses each into a prompt plus switches with ``parser``, runs
    ``t2i.prompt2image`` and writes the resulting images and a log entry.

    Parameters:
        t2i: generator object; this loop reads ``t2i.grid`` and calls
            ``t2i.prompt2image``.
        outdir: default output directory.
        prompt_as_dir: if true, images go to a subdirectory of ``outdir``
            named after the sanitized prompt.
        parser: argparse parser for a single command line.
        infile: open text stream of commands, or None for interactive use.
    """
    done = False
    path_filter = re.compile(r'[<>:"/\\|?*]')
    last_results = list()

    # os.pathconf is not available on Windows
    if hasattr(os, 'pathconf'):
        path_max = os.pathconf(outdir, 'PC_PATH_MAX')
        name_max = os.pathconf(outdir, 'PC_NAME_MAX')
    else:
        path_max = 260
        name_max = 255

    while not done:
        try:
            command = get_next_command(infile)
        except EOFError:
            done = True
            continue
        except KeyboardInterrupt:
            done = True
            continue

        # skip empty lines
        if not command.strip():
            continue

        if command.startswith(('#', '//')):
            continue

        # before splitting, escape single quotes so as not to mess
        # up the parser
        command = command.replace("'", "\\'")

        try:
            elements = shlex.split(command)
        except ValueError as e:
            print(str(e))
            continue

        # defensive: nothing to do if the splitter produced no tokens
        if not elements:
            continue

        if elements[0] == 'q':
            done = True
            break

        if elements[0].startswith(
            '!dream'
        ):  # in case a stored prompt still contains the !dream command
            elements.pop(0)

        # rearrange the arguments to mimic how it works in the Dream bot.
        # Everything before the first token starting with '-' is the prompt;
        # that token and everything after it are switches.
        prompt_words = []
        switches = ['']
        switches_started = False

        for el in elements:
            # startswith() rather than el[0]: shlex yields '' for an empty
            # quoted string (e.g. `a "" b`), and ''[0] raises IndexError
            if el.startswith('-') and not switches_started:
                switches_started = True
            if switches_started:
                switches.append(el)
            else:
                prompt_words.append(el)
        switches[0] = ' '.join(prompt_words)

        try:
            opt = parser.parse_args(switches)
        except SystemExit:
            parser.print_help()
            continue
        if len(opt.prompt) == 0:
            print('Try again with a prompt!')
            continue

        # retrieve previous value!
        if opt.init_img is not None and re.match('^-\\d+$', opt.init_img):
            try:
                opt.init_img = last_results[int(opt.init_img)][0]
                print(f'>> Reusing previous image {opt.init_img}')
            except IndexError:
                print(
                    f'>> No previous initial image at position {opt.init_img} found')
                opt.init_img = None
                continue

        if opt.seed is not None and opt.seed < 0:  # retrieve previous value!
            try:
                opt.seed = last_results[opt.seed][1]
                print(f'>> Reusing previous seed {opt.seed}')
            except IndexError:
                print(f'>> No previous seed at position {opt.seed} found')
                opt.seed = None
                continue

        do_grid = opt.grid or t2i.grid

        if opt.with_variations is not None:
            # shotgun parsing, woo
            parts = []
            broken = False  # python doesn't have labeled loops...
            for part in opt.with_variations.split(','):
                seed_and_weight = part.split(':')
                if len(seed_and_weight) != 2:
                    print(f'could not parse with_variation part "{part}"')
                    broken = True
                    break
                try:
                    seed = int(seed_and_weight[0])
                    weight = float(seed_and_weight[1])
                except ValueError:
                    print(f'could not parse with_variation part "{part}"')
                    broken = True
                    break
                parts.append([seed, weight])
            if broken:
                continue
            if len(parts) > 0:
                opt.with_variations = parts
            else:
                opt.with_variations = None

        if opt.outdir:
            if not os.path.exists(opt.outdir):
                os.makedirs(opt.outdir)
            current_outdir = opt.outdir
        elif prompt_as_dir:
            # sanitize the prompt to a valid folder name
            subdir = path_filter.sub('_', opt.prompt)[:name_max].rstrip(' .')

            # truncate path to maximum allowed length
            # 27 is the length of '######.##########.##.png', plus two separators and a NUL
            subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
            current_outdir = os.path.join(outdir, subdir)

            print('Writing files to directory: "' + current_outdir + '"')

            # make sure the output directory exists
            if not os.path.exists(current_outdir):
                os.makedirs(current_outdir)
        else:
            current_outdir = outdir

        # Here is where the images are actually generated!
        last_results = []
        try:
            file_writer = PngWriter(current_outdir)
            prefix = file_writer.unique_prefix()
            results = []  # list of filename, prompt pairs
            grid_images = dict()  # seed -> Image, only used if `do_grid`

            def image_writer(image, seed, upscaled=False):
                """Callback for each generated image.

                In grid mode the image is only collected; otherwise it is
                written to disk with its reconstructed prompt as metadata.
                Every call records [path, seed] in last_results (path is
                None in grid mode).
                """
                path = None
                if do_grid:
                    grid_images[seed] = image
                else:
                    if upscaled and opt.save_original:
                        filename = f'{prefix}.{seed}.postprocessed.png'
                    else:
                        filename = f'{prefix}.{seed}.png'
                    if opt.variation_amount > 0:
                        iter_opt = argparse.Namespace(**vars(opt))  # copy
                        this_variation = [[seed, opt.variation_amount]]
                        if opt.with_variations is None:
                            iter_opt.with_variations = this_variation
                        else:
                            iter_opt.with_variations = opt.with_variations + this_variation
                        iter_opt.variation_amount = 0
                        normalized_prompt = PromptFormatter(
                            t2i, iter_opt).normalize_prompt()
                        metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
                    elif opt.with_variations is not None:
                        normalized_prompt = PromptFormatter(
                            t2i, opt).normalize_prompt()
                        # use the original seed - the per-iteration value is the last variation-seed
                        metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
                    else:
                        normalized_prompt = PromptFormatter(
                            t2i, opt).normalize_prompt()
                        metadata_prompt = f'{normalized_prompt} -S{seed}'
                    path = file_writer.save_image_and_prompt_to_png(
                        image, metadata_prompt, filename)
                    if (not upscaled) or opt.save_original:
                        # only append to results if we didn't overwrite an earlier output
                        results.append([path, metadata_prompt])
                last_results.append([path, seed])

            t2i.prompt2image(image_callback=image_writer, **vars(opt))

            if do_grid and len(grid_images) > 0:
                grid_img = make_grid(list(grid_images.values()))
                grid_seeds = list(grid_images.keys())
                first_seed = last_results[0][1]
                filename = f'{prefix}.{first_seed}.png'
                # TODO better metadata for grid images
                normalized_prompt = PromptFormatter(
                    t2i, opt).normalize_prompt()
                metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
                path = file_writer.save_image_and_prompt_to_png(
                    grid_img, metadata_prompt, filename
                )
                results = [[path, metadata_prompt]]

        except AssertionError as e:
            print(e)
            continue

        except OSError as e:
            print(e)
            continue

        print('Outputs:')
        log_path = os.path.join(current_outdir, 'dream_log.txt')
        write_log_message(results, log_path)
        print()

    print('goodbye!')


def get_next_command(infile=None) -> str:  # command string
    """Return the next command line.

    With no ``infile`` the user is prompted with ``dream> ``.  Otherwise
    one line is read from ``infile``, stripped, and echoed to the
    terminal prefixed with ``#``.

    Raises:
        EOFError: when ``infile`` is exhausted (``input`` may also raise
            it in interactive mode).
    """
    if infile is None:
        command = input('dream> ')
    else:
        command = infile.readline()
        if not command:
            raise EOFError
        else:
            command = command.strip()
        print(f'#{command}')
    return command


def dream_server_loop(t2i, host, port, outdir):
    """Serve the web interface until interrupted with Ctrl-C.

    Changes the working directory to the parent of this script's
    directory, attaches the model and output directory to ``DreamServer``
    and runs a ``ThreadingDreamServer`` bound to ``(host, port)``.
    """
    print('\n* --web was specified, starting web server...')
    # Change working directory to the stable-diffusion directory
    os.chdir(
        os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    )

    # Start server
    DreamServer.model = t2i
    DreamServer.outdir = outdir
    dream_server = ThreadingDreamServer((host, port))
    print(">> Started Stable Diffusion dream server!")
    if host == '0.0.0.0':
        print(
            f"Point your browser at http://localhost:{port} or use the host's DNS name or IP address.")
    else:
        print(">> Default host address now 127.0.0.1 (localhost). Use --host 0.0.0.0 to bind any address.")
        print(f">> Point your browser at http://{host}:{port}.")

    try:
        dream_server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server
        pass
    finally:
        # always release the listening socket, whatever ended the loop
        dream_server.server_close()


def write_log_message(results, log_path):
    """logs the name of the output image, prompt, and prompt args to the terminal and log file

    ``results`` is an iterable of (path, prompt) pairs.  Each entry is
    printed with a running counter (the module-level ``output_cntr``) and
    all entries are appended to the file at ``log_path``.
    """
    global output_cntr
    log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
    for line in log_lines:
        output_cntr += 1
        print(f'[{output_cntr}] {line}', end='')

    with open(log_path, 'a', encoding='utf-8') as file:
        file.writelines(log_lines)


SAMPLER_CHOICES = [
    'ddim',
    'k_dpm_2_a',
    'k_dpm_2',
    'k_euler_a',
    'k_euler',
    'k_heun',
    'k_lms',
    'plms',
]


def create_argv_parser():
    """Build the parser for the script's launch-time command line."""
    parser = argparse.ArgumentParser(
        description="""Generate images using Stable Diffusion.
        Use --web to launch the web interface.
        Use --from_file to load prompts from a file path or standard input ("-").
        Otherwise you will be dropped into an interactive command prompt (type -h for help.)
        Other command-line arguments are defaults that can usually be overridden
        from the command prompt.
        """
    )
    parser.add_argument(
        '--laion400m',
        '--latent_diffusion',
        '-l',
        dest='laion400m',
        action='store_true',
        help='Fallback to the latent diffusion (laion400m) weights and config',
    )
    parser.add_argument(
        '--from_file',
        dest='infile',
        type=str,
        help='If specified, load prompts from this file',
    )
    parser.add_argument(
        '-n',
        '--iterations',
        type=int,
        default=1,
        help='Number of images to generate',
    )
    parser.add_argument(
        '-F',
        '--full_precision',
        dest='full_precision',
        action='store_true',
        help='Use more memory-intensive full precision math for calculations',
    )
    parser.add_argument(
        '-g',
        '--grid',
        action='store_true',
        help='Generate a grid instead of individual images',
    )
    parser.add_argument(
        '-A',
        '-m',
        '--sampler',
        dest='sampler_name',
        choices=SAMPLER_CHOICES,
        metavar='SAMPLER_NAME',
        default='k_lms',
        help=f'Set the initial sampler. Default: k_lms. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
    )
    parser.add_argument(
        '--outdir',
        '-o',
        type=str,
        default='outputs/img-samples',
        help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
    )
    parser.add_argument(
        '--seamless',
        action='store_true',
        help='Change the model to seamless tiling (circular) mode',
    )
    parser.add_argument(
        '--embedding_path',
        type=str,
        help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
    )
    parser.add_argument(
        '--prompt_as_dir',
        '-p',
        action='store_true',
        help='Place images in subdirectories named after the prompt.',
    )
    # GFPGAN related args
    parser.add_argument(
        '--gfpgan_bg_upsampler',
        type=str,
        default='realesrgan',
        help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
    )
    parser.add_argument(
        '--gfpgan_bg_tile',
        type=int,
        default=400,
        help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
    )
    parser.add_argument(
        '--gfpgan_model_path',
        type=str,
        default='experiments/pretrained_models/GFPGANv1.3.pth',
        help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
    )
    parser.add_argument(
        '--gfpgan_dir',
        type=str,
        default='./src/gfpgan',
        help='Indicates the directory containing the GFPGAN code.',
    )
    parser.add_argument(
        '--web',
        dest='web',
        action='store_true',
        help='Start in web server mode.',
    )
    parser.add_argument(
        '--host',
        type=str,
        default='127.0.0.1',
        help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
    )
    parser.add_argument(
        '--port',
        type=int,
        default=9090,  # an int, matching type=int (was the string '9090')
        help='Web server: Port to listen on'
    )
    parser.add_argument(
        '--weights',
        default='model',
        help='Indicates the Stable Diffusion model to use.',
    )
    parser.add_argument(
        '--device',
        '-d',
        type=str,
        default='cuda',
        help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
    )
    parser.add_argument(
        '--model',
        default='stable-diffusion-1.4',
        help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
    )
    parser.add_argument(
        '--config',
        default='configs/models.yaml',
        help='Path to configuration file for alternate models.',
    )
    return parser


def create_cmd_parser():
    """Build the parser for a single command entered at the dream> prompt."""
    parser = argparse.ArgumentParser(
        description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
    )
    parser.add_argument('prompt')
    parser.add_argument('-s', '--steps', type=int, help='Number of steps')
    parser.add_argument(
        '-S',
        '--seed',
        type=int,
        help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
    )
    parser.add_argument(
        '-n',
        '--iterations',
        type=int,
        default=1,
        help='Number of samplings to perform (slower, but will provide seeds for individual images)',
    )
    parser.add_argument(
        '-W', '--width', type=int, help='Image width, multiple of 64'
    )
    parser.add_argument(
        '-H', '--height', type=int, help='Image height, multiple of 64'
    )
    parser.add_argument(
        '-C',
        '--cfg_scale',
        default=7.5,
        type=float,
        help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
    )
    parser.add_argument(
        '-g', '--grid', action='store_true', help='generate a grid'
    )
    parser.add_argument(
        '--outdir',
        '-o',
        type=str,
        default=None,
        help='Directory to save generated images and a log of prompts and seeds',
    )
    parser.add_argument(
        '--seamless',
        action='store_true',
        help='Change the model to seamless tiling (circular) mode',
    )
    parser.add_argument(
        '-i',
        '--individual',
        action='store_true',
        help='Generate individual files (default)',
    )
    parser.add_argument(
        '-I',
        '--init_img',
        type=str,
        help='Path to input image for img2img mode (supersedes width and height)',
    )
    parser.add_argument(
        '-M',
        '--init_mask',
        type=str,
        help='Path to input mask for inpainting mode (supersedes width and height)',
    )
    parser.add_argument(
        '-T',
        '-fit',
        '--fit',
        action='store_true',
        help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
    )
    parser.add_argument(
        '-f',
        '--strength',
        default=0.75,
        type=float,
        help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
    )
    parser.add_argument(
        '-G',
        '--gfpgan_strength',
        default=0,
        type=float,
        help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
    )
    parser.add_argument(
        '-U',
        '--upscale',
        nargs='+',
        default=None,
        type=float,
        help='Scale factor (2, 4) for upscaling followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
    )
    parser.add_argument(
        '-save_orig',
        '--save_original',
        action='store_true',
        help='Save original. Use it when upscaling to save both versions.',
    )
    # variants is going to be superseded by a generalized "prompt-morph" function
    # parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
    parser.add_argument(
        '-x',
        '--skip_normalize',
        action='store_true',
        help='Skip subprompt weight normalization',
    )
    parser.add_argument(
        '-A',
        '-m',
        '--sampler',
        dest='sampler_name',
        default=None,
        type=str,
        choices=SAMPLER_CHOICES,
        metavar='SAMPLER_NAME',
        help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
    )
    parser.add_argument(
        '-t',
        '--log_tokenization',
        action='store_true',
        help='shows how the prompt is split into tokens'
    )
    parser.add_argument(
        '-v',
        '--variation_amount',
        default=0.0,
        type=float,
        help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
    )
    parser.add_argument(
        '-V',
        '--with_variations',
        default=None,
        type=str,
        help='list of variations to apply, in the format `seed:weight,seed:weight,...'
    )
    return parser


if __name__ == '__main__':
    main()