From 171f8db742f18532b6fa03cdfbf4be2bbf6cf3ad Mon Sep 17 00:00:00 2001 From: Denis Olshin Date: Thu, 8 Sep 2022 03:15:20 +0300 Subject: [PATCH] saving full prompt to metadata when using web ui --- ldm/dream/server.py | 170 +++++++++++++++++------------------- static/dream_web/index.html | 12 +-- 2 files changed, 85 insertions(+), 97 deletions(-) diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 384cbc37a8..ed176fc457 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -1,11 +1,59 @@ +import argparse import json import base64 import mimetypes import os from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from ldm.dream.pngwriter import PngWriter +from ldm.dream.pngwriter import PngWriter, PromptFormatter from threading import Event +def build_opt(post_data, seed, gfpgan_model_exists): + opt = argparse.Namespace() + setattr(opt, 'prompt', post_data['prompt']) + setattr(opt, 'init_img', post_data['initimg']) + setattr(opt, 'strength', float(post_data['strength'])) + setattr(opt, 'iterations', int(post_data['iterations'])) + setattr(opt, 'steps', int(post_data['steps'])) + setattr(opt, 'width', int(post_data['width'])) + setattr(opt, 'height', int(post_data['height'])) + setattr(opt, 'seamless', 'seamless' in post_data) + setattr(opt, 'fit', 'fit' in post_data) + setattr(opt, 'mask', 'mask' in post_data) + setattr(opt, 'invert_mask', 'invert_mask' in post_data) + setattr(opt, 'cfg_scale', float(post_data['cfg_scale'])) + setattr(opt, 'sampler_name', post_data['sampler_name']) + setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0) + setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None) + setattr(opt, 'progress_images', 'progress_images' in post_data) + setattr(opt, 'seed', seed if int(post_data['seed']) == -1 else int(post_data['seed'])) + setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0) + setattr(opt, 'with_variations', []) + + broken = False + if int(post_data['seed']) != -1 and post_data['with_variations'] != '': + for part in post_data['with_variations'].split(','): + seed_and_weight = part.split(':') + if len(seed_and_weight) != 2: + print(f'could not parse with_variation part "{part}"') + broken = True + break + try: + seed = int(seed_and_weight[0]) + weight = float(seed_and_weight[1]) + except ValueError: + print(f'could not parse with_variation part "{part}"') + broken = True + break + opt.with_variations.append([seed, weight]) + + if broken: + raise CanceledException + + if len(opt.with_variations) == 0: + opt.with_variations = None + + return opt + class CanceledException(Exception): pass @@ -64,57 +112,15 @@ class DreamServer(BaseHTTPRequestHandler): content_length = int(self.headers['Content-Length']) post_data = json.loads(self.rfile.read(content_length)) - prompt = post_data['prompt'] - initimg = post_data['initimg'] - strength = float(post_data['strength']) - iterations = int(post_data['iterations']) - steps = int(post_data['steps']) - width = int(post_data['width']) - height = int(post_data['height']) - fit = 'fit' in post_data - seamless = 'seamless' in post_data - cfgscale = float(post_data['cfgscale']) - sampler_name = post_data['sampler'] - variation_amount = float(post_data['variation_amount']) if int(post_data['seed']) == -1 else 0.0 - with_variations = post_data['with_variations'] if int(post_data['seed']) == -1 else '' - gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0 - upscale_level = post_data['upscale_level'] - upscale_strength = post_data['upscale_strength'] - upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None - progress_images = 'progress_images' in post_data - seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed']) - - if with_variations != '': - parts = [] - broken = False - for part in with_variations.split(','): - seed_and_weight = part.split(':') - if len(seed_and_weight) != 2: - print(f'could not parse with_variation part "{part}"') - broken = True - break - try: - vseed = int(seed_and_weight[0]) - vweight = float(seed_and_weight[1]) - except ValueError: - print(f'could not parse with_variation part "{part}"') - broken = True - break - parts.append([vseed, vweight]) - if broken: - raise CanceledException - if len(parts) > 0: - with_variations = parts - else: - with_variations = None + opt = build_opt(post_data, self.model.seed, gfpgan_model_exists) self.canceled.clear() - print(f">> Request to generate with prompt: {prompt}") + print(f">> Request to generate with prompt: {opt.prompt}") # In order to handle upscaled images, the PngWriter needs to maintain state # across images generated by each call to prompt2img(), so we define it in # the outer scope of image_done() config = post_data.copy() # Shallow copy - config['initimg'] = config.pop('initimg_name','') + config['initimg'] = config.pop('initimg_name', '') images_generated = 0 # helps keep track of when upscaling is started images_upscaled = 0 # helps keep track of when upscaling is completed @@ -127,7 +133,18 @@ class DreamServer(BaseHTTPRequestHandler): # entry should not be inserted into the image list. def image_done(image, seed, upscaled=False): name = f'{prefix}.{seed}.png' - path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name) + iter_opt = argparse.Namespace(**vars(opt)) # copy + if opt.variation_amount > 0: + this_variation = [[seed, opt.variation_amount]] + if opt.with_variations is None: + iter_opt.with_variations = this_variation + else: + iter_opt.with_variations = opt.with_variations + this_variation + iter_opt.variation_amount = 0 + elif opt.with_variations is None: + iter_opt.seed = seed + normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt() + path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name) if int(config['seed']) == -1: config['seed'] = seed @@ -141,24 +158,24 @@ class DreamServer(BaseHTTPRequestHandler): ) + '\n',"utf-8")) # control state of the "postprocessing..." message - upscaling_requested = upscale or gfpgan_strength>0 + upscaling_requested = opt.upscale or opt.gfpgan_strength > 0 nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure. nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure. if upscaled: images_upscaled += 1 else: - images_generated +=1 + images_generated += 1 if upscaling_requested: action = None - if images_generated >= iterations: - if images_upscaled < iterations: + if images_generated >= opt.iterations: + if images_upscaled < opt.iterations: action = 'upscaling-started' else: action = 'upscaling-done' if action: - x = images_upscaled+1 + x = images_upscaled + 1 self.wfile.write(bytes(json.dumps( - {'event':action,'processed_file_cnt':f'{x}/{iterations}'} + {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'} ) + '\n',"utf-8")) step_writer = PngWriter(os.path.join(self.outdir, "intermediates")) @@ -171,10 +188,10 @@ class DreamServer(BaseHTTPRequestHandler): # since rendering images is moderately expensive, only render every 5th image # and don't bother with the last one, since it'll render anyway nonlocal step_index - if progress_images and step % 5 == 0 and step < steps - 1: + if opt.progress_images and step % 5 == 0 and step < opt.steps - 1: image = self.model.sample_to_image(sample) - name = f'{prefix}.{seed}.{step_index}.png' - metadata = f'{prompt} -S{seed} [intermediate]' + name = f'{prefix}.{opt.seed}.{step_index}.png' + metadata = f'{opt.prompt} -S{opt.seed} [intermediate]' path = step_writer.save_image_and_prompt_to_png(image, metadata, name) step_index += 1 self.wfile.write(bytes(json.dumps( @@ -182,49 +199,20 @@ class DreamServer(BaseHTTPRequestHandler): ) + '\n',"utf-8")) try: - if initimg is None: + if opt.init_img is None: # Run txt2img - self.model.prompt2image(prompt, - iterations=iterations, - cfg_scale = cfgscale, - width = width, - height = height, - seed = seed, - steps = steps, - variation_amount = variation_amount, - with_variations = with_variations, - gfpgan_strength = gfpgan_strength, - upscale = upscale, - sampler_name = sampler_name, - seamless = seamless, - step_callback=image_progress, - image_callback=image_done) + self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: - initimg = initimg.split(",")[1] # Ignore mime type + initimg = opt.init_img.split(",")[1] # Ignore mime type f.write(base64.b64decode(initimg)) + opt1 = argparse.Namespace(**vars(opt)) + opt1.init_img = "./img2img-tmp.png" try: # Run img2img - self.model.prompt2image(prompt, - init_img = "./img2img-tmp.png", - strength = strength, - iterations = iterations, - cfg_scale = cfgscale, - seed = seed, - steps = steps, - variation_amount = variation_amount, - with_variations = with_variations, - sampler_name = sampler_name, - width = width, - height = height, - fit = fit, - seamless = seamless, - gfpgan_strength=gfpgan_strength, - upscale = upscale, - step_callback=image_progress, - image_callback=image_done) + self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done) finally: # Remove the temp file os.remove("./img2img-tmp.png") diff --git a/static/dream_web/index.html b/static/dream_web/index.html index 466e165ffd..b99fa045a9 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -30,21 +30,21 @@ - - - - + + - +