saving full prompt to metadata when using web ui

This commit is contained in:
Denis Olshin 2022-09-08 03:15:20 +03:00
parent d7e67b62f0
commit 171f8db742
2 changed files with 85 additions and 97 deletions

View File

@ -1,11 +1,59 @@
import argparse
import json import json
import base64 import base64
import mimetypes import mimetypes
import os import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter, PromptFormatter
from threading import Event from threading import Event
def build_opt(post_data, seed, gfpgan_model_exists):
opt = argparse.Namespace()
setattr(opt, 'prompt', post_data['prompt'])
setattr(opt, 'init_img', post_data['initimg'])
setattr(opt, 'strength', float(post_data['strength']))
setattr(opt, 'iterations', int(post_data['iterations']))
setattr(opt, 'steps', int(post_data['steps']))
setattr(opt, 'width', int(post_data['width']))
setattr(opt, 'height', int(post_data['height']))
setattr(opt, 'seamless', 'seamless' in post_data)
setattr(opt, 'fit', 'fit' in post_data)
setattr(opt, 'mask', 'mask' in post_data)
setattr(opt, 'invert_mask', 'invert_mask' in post_data)
setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
setattr(opt, 'sampler_name', post_data['sampler_name'])
setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
setattr(opt, 'progress_images', 'progress_images' in post_data)
setattr(opt, 'seed', seed if int(post_data['seed']) == -1 else int(post_data['seed']))
setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
setattr(opt, 'with_variations', [])
broken = False
if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
for part in post_data['with_variations'].split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'could not parse with_variation part "{part}"')
broken = True
break
opt.with_variations.append([seed, weight])
if broken:
raise CanceledException
if len(opt.with_variations) == 0:
opt.with_variations = None
return opt
class CanceledException(Exception): class CanceledException(Exception):
pass pass
@ -64,52 +112,10 @@ class DreamServer(BaseHTTPRequestHandler):
content_length = int(self.headers['Content-Length']) content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length)) post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt'] opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
initimg = post_data['initimg']
strength = float(post_data['strength'])
iterations = int(post_data['iterations'])
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
fit = 'fit' in post_data
seamless = 'seamless' in post_data
cfgscale = float(post_data['cfgscale'])
sampler_name = post_data['sampler']
variation_amount = float(post_data['variation_amount']) if int(post_data['seed']) == -1 else 0.0
with_variations = post_data['with_variations'] if int(post_data['seed']) == -1 else ''
gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0
upscale_level = post_data['upscale_level']
upscale_strength = post_data['upscale_strength']
upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
progress_images = 'progress_images' in post_data
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
if with_variations != '':
parts = []
broken = False
for part in with_variations.split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
broken = True
break
try:
vseed = int(seed_and_weight[0])
vweight = float(seed_and_weight[1])
except ValueError:
print(f'could not parse with_variation part "{part}"')
broken = True
break
parts.append([vseed, vweight])
if broken:
raise CanceledException
if len(parts) > 0:
with_variations = parts
else:
with_variations = None
self.canceled.clear() self.canceled.clear()
print(f">> Request to generate with prompt: {prompt}") print(f">> Request to generate with prompt: {opt.prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state # In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in # across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done() # the outer scope of image_done()
@ -127,7 +133,18 @@ class DreamServer(BaseHTTPRequestHandler):
# entry should not be inserted into the image list. # entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False): def image_done(image, seed, upscaled=False):
name = f'{prefix}.{seed}.png' name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name) iter_opt = argparse.Namespace(**vars(opt)) # copy
if opt.variation_amount > 0:
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
elif opt.with_variations is None:
iter_opt.seed = seed
normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
if int(config['seed']) == -1: if int(config['seed']) == -1:
config['seed'] = seed config['seed'] = seed
@ -141,7 +158,7 @@ class DreamServer(BaseHTTPRequestHandler):
) + '\n',"utf-8")) ) + '\n',"utf-8"))
# control state of the "postprocessing..." message # control state of the "postprocessing..." message
upscaling_requested = upscale or gfpgan_strength>0 upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure. nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure. nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled: if upscaled:
@ -150,15 +167,15 @@ class DreamServer(BaseHTTPRequestHandler):
images_generated += 1 images_generated += 1
if upscaling_requested: if upscaling_requested:
action = None action = None
if images_generated >= iterations: if images_generated >= opt.iterations:
if images_upscaled < iterations: if images_upscaled < opt.iterations:
action = 'upscaling-started' action = 'upscaling-started'
else: else:
action = 'upscaling-done' action = 'upscaling-done'
if action: if action:
x = images_upscaled + 1 x = images_upscaled + 1
self.wfile.write(bytes(json.dumps( self.wfile.write(bytes(json.dumps(
{'event':action,'processed_file_cnt':f'{x}/{iterations}'} {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
) + '\n',"utf-8")) ) + '\n',"utf-8"))
step_writer = PngWriter(os.path.join(self.outdir, "intermediates")) step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
@ -171,10 +188,10 @@ class DreamServer(BaseHTTPRequestHandler):
# since rendering images is moderately expensive, only render every 5th image # since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway # and don't bother with the last one, since it'll render anyway
nonlocal step_index nonlocal step_index
if progress_images and step % 5 == 0 and step < steps - 1: if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
image = self.model.sample_to_image(sample) image = self.model.sample_to_image(sample)
name = f'{prefix}.{seed}.{step_index}.png' name = f'{prefix}.{opt.seed}.{step_index}.png'
metadata = f'{prompt} -S{seed} [intermediate]' metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name) path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
step_index += 1 step_index += 1
self.wfile.write(bytes(json.dumps( self.wfile.write(bytes(json.dumps(
@ -182,49 +199,20 @@ class DreamServer(BaseHTTPRequestHandler):
) + '\n',"utf-8")) ) + '\n',"utf-8"))
try: try:
if initimg is None: if opt.init_img is None:
# Run txt2img # Run txt2img
self.model.prompt2image(prompt, self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
variation_amount = variation_amount,
with_variations = with_variations,
gfpgan_strength = gfpgan_strength,
upscale = upscale,
sampler_name = sampler_name,
seamless = seamless,
step_callback=image_progress,
image_callback=image_done)
else: else:
# Decode initimg as base64 to temp file # Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f: with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type initimg = opt.init_img.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg)) f.write(base64.b64decode(initimg))
opt1 = argparse.Namespace(**vars(opt))
opt1.init_img = "./img2img-tmp.png"
try: try:
# Run img2img # Run img2img
self.model.prompt2image(prompt, self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
init_img = "./img2img-tmp.png",
strength = strength,
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
variation_amount = variation_amount,
with_variations = with_variations,
sampler_name = sampler_name,
width = width,
height = height,
fit = fit,
seamless = seamless,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
step_callback=image_progress,
image_callback=image_done)
finally: finally:
# Remove the temp file # Remove the temp file
os.remove("./img2img-tmp.png") os.remove("./img2img-tmp.png")

View File

@ -30,10 +30,10 @@
<input value="1" type="number" id="iterations" name="iterations" size="4"> <input value="1" type="number" id="iterations" name="iterations" size="4">
<label for="steps">Steps:</label> <label for="steps">Steps:</label>
<input value="50" type="number" id="steps" name="steps"> <input value="50" type="number" id="steps" name="steps">
<label for="cfgscale">Cfg Scale:</label> <label for="cfg_scale">Cfg Scale:</label>
<input value="7.5" type="number" id="cfgscale" name="cfgscale" step="any"> <input value="7.5" type="number" id="cfg_scale" name="cfg_scale" step="any">
<label for="sampler">Sampler:</label> <label for="sampler_name">Sampler:</label>
<select id="sampler" name="sampler" value="k_lms"> <select id="sampler_name" name="sampler_name" value="k_lms">
<option value="ddim">DDIM</option> <option value="ddim">DDIM</option>
<option value="plms">PLMS</option> <option value="plms">PLMS</option>
<option value="k_lms" selected>KLMS</option> <option value="k_lms" selected>KLMS</option>