diff --git a/ldm/dream/server.py b/ldm/dream/server.py
index 384cbc37a8..ed176fc457 100644
--- a/ldm/dream/server.py
+++ b/ldm/dream/server.py
@@ -1,11 +1,59 @@
+import argparse
import json
import base64
import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
-from ldm.dream.pngwriter import PngWriter
+from ldm.dream.pngwriter import PngWriter, PromptFormatter
from threading import Event
+def build_opt(post_data, seed, gfpgan_model_exists):
+ opt = argparse.Namespace()
+ setattr(opt, 'prompt', post_data['prompt'])
+ setattr(opt, 'init_img', post_data['initimg'])
+ setattr(opt, 'strength', float(post_data['strength']))
+ setattr(opt, 'iterations', int(post_data['iterations']))
+ setattr(opt, 'steps', int(post_data['steps']))
+ setattr(opt, 'width', int(post_data['width']))
+ setattr(opt, 'height', int(post_data['height']))
+ setattr(opt, 'seamless', 'seamless' in post_data)
+ setattr(opt, 'fit', 'fit' in post_data)
+ setattr(opt, 'mask', 'mask' in post_data)
+ setattr(opt, 'invert_mask', 'invert_mask' in post_data)
+ setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
+ setattr(opt, 'sampler_name', post_data['sampler_name'])
+ setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
+ setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
+ setattr(opt, 'progress_images', 'progress_images' in post_data)
+ setattr(opt, 'seed', seed if int(post_data['seed']) == -1 else int(post_data['seed']))
+ setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
+ setattr(opt, 'with_variations', [])
+
+ broken = False
+ if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
+ for part in post_data['with_variations'].split(','):
+ seed_and_weight = part.split(':')
+ if len(seed_and_weight) != 2:
+ print(f'could not parse with_variation part "{part}"')
+ broken = True
+ break
+ try:
+ seed = int(seed_and_weight[0])
+ weight = float(seed_and_weight[1])
+ except ValueError:
+ print(f'could not parse with_variation part "{part}"')
+ broken = True
+ break
+ opt.with_variations.append([seed, weight])
+
+ if broken:
+ raise CanceledException
+
+ if len(opt.with_variations) == 0:
+ opt.with_variations = None
+
+ return opt
+
class CanceledException(Exception):
pass
@@ -64,57 +112,15 @@ class DreamServer(BaseHTTPRequestHandler):
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
- prompt = post_data['prompt']
- initimg = post_data['initimg']
- strength = float(post_data['strength'])
- iterations = int(post_data['iterations'])
- steps = int(post_data['steps'])
- width = int(post_data['width'])
- height = int(post_data['height'])
- fit = 'fit' in post_data
- seamless = 'seamless' in post_data
- cfgscale = float(post_data['cfgscale'])
- sampler_name = post_data['sampler']
- variation_amount = float(post_data['variation_amount']) if int(post_data['seed']) == -1 else 0.0
- with_variations = post_data['with_variations'] if int(post_data['seed']) == -1 else ''
- gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0
- upscale_level = post_data['upscale_level']
- upscale_strength = post_data['upscale_strength']
- upscale = [int(upscale_level),float(upscale_strength)] if upscale_level != '' else None
- progress_images = 'progress_images' in post_data
- seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
-
- if with_variations != '':
- parts = []
- broken = False
- for part in with_variations.split(','):
- seed_and_weight = part.split(':')
- if len(seed_and_weight) != 2:
- print(f'could not parse with_variation part "{part}"')
- broken = True
- break
- try:
- vseed = int(seed_and_weight[0])
- vweight = float(seed_and_weight[1])
- except ValueError:
- print(f'could not parse with_variation part "{part}"')
- broken = True
- break
- parts.append([vseed, vweight])
- if broken:
- raise CanceledException
- if len(parts) > 0:
- with_variations = parts
- else:
- with_variations = None
+ opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
self.canceled.clear()
- print(f">> Request to generate with prompt: {prompt}")
+ print(f">> Request to generate with prompt: {opt.prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
config = post_data.copy() # Shallow copy
- config['initimg'] = config.pop('initimg_name','')
+ config['initimg'] = config.pop('initimg_name', '')
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
@@ -127,7 +133,18 @@ class DreamServer(BaseHTTPRequestHandler):
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
name = f'{prefix}.{seed}.png'
- path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
+ iter_opt = argparse.Namespace(**vars(opt)) # copy
+ if opt.variation_amount > 0:
+ this_variation = [[seed, opt.variation_amount]]
+ if opt.with_variations is None:
+ iter_opt.with_variations = this_variation
+ else:
+ iter_opt.with_variations = opt.with_variations + this_variation
+ iter_opt.variation_amount = 0
+ elif opt.with_variations is None:
+ iter_opt.seed = seed
+ normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
+ path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
if int(config['seed']) == -1:
config['seed'] = seed
@@ -141,24 +158,24 @@ class DreamServer(BaseHTTPRequestHandler):
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
- upscaling_requested = upscale or gfpgan_strength>0
+ upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled:
images_upscaled += 1
else:
- images_generated +=1
+ images_generated += 1
if upscaling_requested:
action = None
- if images_generated >= iterations:
- if images_upscaled < iterations:
+ if images_generated >= opt.iterations:
+ if images_upscaled < opt.iterations:
action = 'upscaling-started'
else:
action = 'upscaling-done'
if action:
- x = images_upscaled+1
+ x = images_upscaled + 1
self.wfile.write(bytes(json.dumps(
- {'event':action,'processed_file_cnt':f'{x}/{iterations}'}
+ {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
) + '\n',"utf-8"))
step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
@@ -171,10 +188,10 @@ class DreamServer(BaseHTTPRequestHandler):
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
nonlocal step_index
- if progress_images and step % 5 == 0 and step < steps - 1:
+ if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
image = self.model.sample_to_image(sample)
- name = f'{prefix}.{seed}.{step_index}.png'
- metadata = f'{prompt} -S{seed} [intermediate]'
+ name = f'{prefix}.{opt.seed}.{step_index}.png'
+ metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
step_index += 1
self.wfile.write(bytes(json.dumps(
@@ -182,49 +199,20 @@ class DreamServer(BaseHTTPRequestHandler):
) + '\n',"utf-8"))
try:
- if initimg is None:
+ if opt.init_img is None:
# Run txt2img
- self.model.prompt2image(prompt,
- iterations=iterations,
- cfg_scale = cfgscale,
- width = width,
- height = height,
- seed = seed,
- steps = steps,
- variation_amount = variation_amount,
- with_variations = with_variations,
- gfpgan_strength = gfpgan_strength,
- upscale = upscale,
- sampler_name = sampler_name,
- seamless = seamless,
- step_callback=image_progress,
- image_callback=image_done)
+ self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
- initimg = initimg.split(",")[1] # Ignore mime type
+ initimg = opt.init_img.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
+ opt1 = argparse.Namespace(**vars(opt))
+ opt1.init_img = "./img2img-tmp.png"
try:
# Run img2img
- self.model.prompt2image(prompt,
- init_img = "./img2img-tmp.png",
- strength = strength,
- iterations = iterations,
- cfg_scale = cfgscale,
- seed = seed,
- steps = steps,
- variation_amount = variation_amount,
- with_variations = with_variations,
- sampler_name = sampler_name,
- width = width,
- height = height,
- fit = fit,
- seamless = seamless,
- gfpgan_strength=gfpgan_strength,
- upscale = upscale,
- step_callback=image_progress,
- image_callback=image_done)
+ self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
finally:
# Remove the temp file
os.remove("./img2img-tmp.png")
diff --git a/static/dream_web/index.html b/static/dream_web/index.html
index 466e165ffd..b99fa045a9 100644
--- a/static/dream_web/index.html
+++ b/static/dream_web/index.html
@@ -30,21 +30,21 @@
-
-
-
-