# InvokeAI/ldm/dream/server.py

import argparse
import json
import base64
import mimetypes
import os

from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Event

from ldm.dream.pngwriter import PngWriter, PromptFormatter

def build_opt(post_data, seed, gfpgan_model_exists):
    opt = argparse.Namespace()
    setattr(opt, 'prompt', post_data['prompt'])
    setattr(opt, 'init_img', post_data['initimg'])
    setattr(opt, 'strength', float(post_data['strength']))
    setattr(opt, 'iterations', int(post_data['iterations']))
    setattr(opt, 'steps', int(post_data['steps']))
    setattr(opt, 'width', int(post_data['width']))
    setattr(opt, 'height', int(post_data['height']))
    setattr(opt, 'seamless', 'seamless' in post_data)
    setattr(opt, 'fit', 'fit' in post_data)
    setattr(opt, 'mask', 'mask' in post_data)
    setattr(opt, 'invert_mask', 'invert_mask' in post_data)
    setattr(opt, 'cfg_scale', float(post_data['cfg_scale']))
    setattr(opt, 'sampler_name', post_data['sampler_name'])
    setattr(opt, 'gfpgan_strength', float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0)
    setattr(opt, 'upscale', [int(post_data['upscale_level']), float(post_data['upscale_strength'])] if post_data['upscale_level'] != '' else None)
    setattr(opt, 'progress_images', 'progress_images' in post_data)
    setattr(opt, 'seed', None if int(post_data['seed']) == -1 else int(post_data['seed']))
    setattr(opt, 'variation_amount', float(post_data['variation_amount']) if int(post_data['seed']) != -1 else 0)
    setattr(opt, 'with_variations', [])

    broken = False
    if int(post_data['seed']) != -1 and post_data['with_variations'] != '':
        for part in post_data['with_variations'].split(','):
            seed_and_weight = part.split(':')
            if len(seed_and_weight) != 2:
                print(f'could not parse with_variation part "{part}"')
                broken = True
                break
            try:
                seed = int(seed_and_weight[0])
                weight = float(seed_and_weight[1])
            except ValueError:
                print(f'could not parse with_variation part "{part}"')
                broken = True
                break
            opt.with_variations.append([seed, weight])

    if broken:
        raise CanceledException

    if len(opt.with_variations) == 0:
        opt.with_variations = None

    return opt
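
# For illustration only (not exercised anywhere in this file): 'with_variations'
# is parsed above as a comma-separated list of seed:weight pairs, so a
# hypothetical request body could contain something like the following. The
# concrete values are made up; only the shape matters, and the field is only
# honored when an explicit seed (not -1) is supplied.
#
#   {
#       "prompt": "a sunlit meadow",
#       "seed": "1234",
#       "with_variations": "42:0.3,99:0.1",
#       ...
#   }
#
# which build_opt() turns into opt.with_variations == [[42, 0.3], [99, 0.1]].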


class CanceledException(Exception):
    pass


class DreamServer(BaseHTTPRequestHandler):
    model = None
    outdir = None
    canceled = Event()

    def do_GET(self):
        if self.path == "/":
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            with open("./static/dream_web/index.html", "rb") as content:
                self.wfile.write(content.read())

        elif self.path == "/config.js":
            # unfortunately this import can't be at the top level, since that would cause a circular import
            from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists
            self.send_response(200)
            self.send_header("Content-type", "application/javascript")
            self.end_headers()
            config = {
                'gfpgan_model_exists': gfpgan_model_exists
            }
            self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))

        elif self.path == "/run_log.json":
            self.send_response(200)
            self.send_header("Content-type", "application/json")
            self.end_headers()
            output = []

            log_file = os.path.join(self.outdir, "dream_web_log.txt")
            if os.path.exists(log_file):
                with open(log_file, "r") as log:
                    for line in log:
                        url, config = line.split(": {", maxsplit=1)
                        config = json.loads("{" + config)
                        config["url"] = url.lstrip(".")
                        if os.path.exists(url):
                            output.append(config)

            self.wfile.write(bytes(json.dumps({"run_log": output}), "utf-8"))

        elif self.path == "/cancel":
            self.canceled.set()
            self.send_response(200)
            self.send_header("Content-type", "application/json")
            self.end_headers()
            self.wfile.write(bytes('{}', 'utf8'))

        else:
            # Fall through: serve static front-end assets and previously
            # generated images from the output directory.
            path_dir = os.path.dirname(self.path)
            out_dir = os.path.realpath(self.outdir.rstrip('/'))
            if self.path.startswith('/static/dream_web/'):
                path = '.' + self.path
            elif out_dir.endswith(path_dir):
                file = os.path.basename(self.path)
                path = os.path.join(self.outdir, file)
            else:
                self.send_response(404)
                return

            mime_type = mimetypes.guess_type(path)[0]
            if mime_type is not None:
                self.send_response(200)
                self.send_header("Content-type", mime_type)
                self.end_headers()
                with open(path, "rb") as content:
                    self.wfile.write(content.read())
            else:
                self.send_response(404)

    def do_POST(self):
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()

        # unfortunately this import can't be at the top level, since that would cause a circular import
        from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists

        content_length = int(self.headers['Content-Length'])
        post_data = json.loads(self.rfile.read(content_length))
        opt = build_opt(post_data, self.model.seed, gfpgan_model_exists)
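
        # Summary of the response protocol implemented by the closures below:
        # instead of a single JSON document, the body is streamed back as
        # newline-delimited JSON events. image_progress() emits 'step' and
        # 'canceled' events; image_done() emits 'result' and, when
        # postprocessing is requested, 'upscaling-started'/'upscaling-done'.
        # An illustrative (made-up) pair of lines:
        #   {"event": "step", "step": 6, "url": null}
        #   {"event": "result", "url": "outputs/000001.1234.png", "seed": 1234, "config": {...}}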

        self.canceled.clear()
        print(f">> Request to generate with prompt: {opt.prompt}")

        # In order to handle upscaled images, the PngWriter needs to maintain state
        # across images generated by each call to prompt2image(), so we define it in
        # the outer scope of image_done()
        config = post_data.copy()  # shallow copy
        config['initimg'] = config.pop('initimg_name', '')
        images_generated = 0  # helps keep track of when upscaling is started
        images_upscaled = 0   # helps keep track of when upscaling is completed
        pngwriter = PngWriter(self.outdir)
        prefix = pngwriter.unique_prefix()

        # If upscaling is requested, image_done() will be called twice: once when
        # the images are first generated, and again after upscaling is complete.
        # The upscaling replaces the original file, so the second entry should not
        # be inserted into the image list.
        def image_done(image, seed, upscaled=False):
            name = f'{prefix}.{seed}.png'
            iter_opt = argparse.Namespace(**vars(opt))  # copy
            if opt.variation_amount > 0:
                this_variation = [[seed, opt.variation_amount]]
                if opt.with_variations is None:
                    iter_opt.with_variations = this_variation
                else:
                    iter_opt.with_variations = opt.with_variations + this_variation
                iter_opt.variation_amount = 0
            elif opt.with_variations is None:
                iter_opt.seed = seed

            normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
            path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)

            if int(config['seed']) == -1:
                config['seed'] = seed

            # Append post_data to log, but only once!
            if not upscaled:
                with open(os.path.join(self.outdir, "dream_web_log.txt"), "a") as log:
                    log.write(f"{path}: {json.dumps(config)}\n")

                self.wfile.write(bytes(json.dumps(
                    {'event': 'result', 'url': path, 'seed': seed, 'config': config}
                ) + '\n', "utf-8"))

            # control state of the "postprocessing..." message
            upscaling_requested = opt.upscale or opt.gfpgan_strength > 0
            nonlocal images_generated  # NB: Is this bad python style? It is typical usage in a perl closure.
            nonlocal images_upscaled   # NB: Is this bad python style? It is typical usage in a perl closure.
            if upscaled:
                images_upscaled += 1
            else:
                images_generated += 1

            if upscaling_requested:
                action = None
                if images_generated >= opt.iterations:
                    if images_upscaled < opt.iterations:
                        action = 'upscaling-started'
                    else:
                        action = 'upscaling-done'
                if action:
                    x = images_upscaled + 1
                    self.wfile.write(bytes(json.dumps(
                        {'event': action, 'processed_file_cnt': f'{x}/{opt.iterations}'}
                    ) + '\n', "utf-8"))

        step_writer = PngWriter(os.path.join(self.outdir, "intermediates"))
        step_index = 1

        def image_progress(sample, step):
            if self.canceled.is_set():
                self.wfile.write(bytes(json.dumps({'event': 'canceled'}) + '\n', 'utf-8'))
                raise CanceledException

            path = None
            # since rendering images is moderately expensive, only render every 5th image
            # and don't bother with the last one, since it'll render anyway
            nonlocal step_index
            if opt.progress_images and step % 5 == 0 and step < opt.steps - 1:
                image = self.model.sample_to_image(sample)
                name = f'{prefix}.{opt.seed}.{step_index}.png'
                metadata = f'{opt.prompt} -S{opt.seed} [intermediate]'
                path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
                step_index += 1

            self.wfile.write(bytes(json.dumps(
                {'event': 'step', 'step': step + 1, 'url': path}
            ) + '\n', "utf-8"))

        try:
            if opt.init_img is None:
                # Run txt2img
                self.model.prompt2image(**vars(opt), step_callback=image_progress, image_callback=image_done)
            else:
                # Decode initimg as base64 and write it to a temp file
                with open("./img2img-tmp.png", "wb") as f:
                    initimg = opt.init_img.split(",")[1]  # ignore the mime type prefix
                    f.write(base64.b64decode(initimg))

                opt1 = argparse.Namespace(**vars(opt))
                opt1.init_img = "./img2img-tmp.png"

                try:
                    # Run img2img
                    self.model.prompt2image(**vars(opt1), step_callback=image_progress, image_callback=image_done)
                finally:
                    # Remove the temp file
                    os.remove("./img2img-tmp.png")
        except CanceledException:
            print("Canceled.")
            return


class ThreadingDreamServer(ThreadingHTTPServer):
    def __init__(self, server_address):
        super(ThreadingDreamServer, self).__init__(server_address, DreamServer)
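

# A minimal, illustrative way to stand the server up. This is not this module's
# normal entry point (which lives elsewhere in the project and is responsible
# for loading the model); the host, port, and output directory below are
# assumptions chosen only for the sketch.
if __name__ == '__main__':
    DreamServer.outdir = 'outputs/img-samples'  # assumed output location
    # DreamServer.model must be assigned a loaded generator object (something
    # exposing prompt2image(), sample_to_image() and a seed attribute) before
    # POST requests can succeed; it is deliberately left unset here.
    print('Serving on http://127.0.0.1:9090 ...')
    dream_server = ThreadingDreamServer(('127.0.0.1', 9090))
    try:
        dream_server.serve_forever()
    except KeyboardInterrupt:
        pass
    finally:
        dream_server.server_close()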