InvokeAI/ldm/dream/server.py
David Ford ed38c97ed8 Removed unrelated changes and updated based on recommendation
Removed the changes to the index.html and .gitattributes for this PR. Will add them in separate PRs.

Applied recommended change for resolving the case issue.
2022-08-29 16:43:34 -05:00

150 lines
6.5 KiB
Python

import base64
import json
import mimetypes
import os
import tempfile
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

from ldm.dream.pngwriter import PngWriter
class DreamServer(BaseHTTPRequestHandler):
    """HTTP request handler for the dream web UI.

    GET serves ``./static/dream_web/index.html`` for ``/`` and regular files
    located below the current working directory for any other path.
    POST runs ``model.prompt2image`` with the parameters found in the JSON
    request body and streams newline-delimited JSON events to the client.
    """

    # Image generator used by do_POST. It is None by default, so it has to be
    # assigned on the class before any POST request is handled.
    model = None

    def do_GET(self):
        """Serve the index page for "/" or a static file below the cwd.

        Responds with 404 when the requested path resolves outside the
        current working directory, is not a regular file, or has no
        guessable MIME type.
        """
        if self.path == "/":
            self._send_file("./static/dream_web/index.html", "text/html")
            return

        requested = "." + self.path
        root = os.path.realpath(os.getcwd())
        resolved = os.path.realpath(requested)
        if not (self._is_inside(resolved, root) and os.path.isfile(resolved)):
            # send_error() also terminates the header block; a bare
            # send_response() would leave the status line unflushed.
            self.send_error(404)
            return

        mime_type = mimetypes.guess_type(requested)[0]
        if mime_type is None:
            self.send_error(404)
            return
        # Open the already-validated, resolved path rather than the raw one.
        self._send_file(resolved, mime_type)

    @staticmethod
    def _is_inside(path, root):
        """Return True if *path* is *root* itself or located below it.

        Both arguments are expected to be absolute, resolved paths. The
        comparison is done on whole path components, so a sibling such as
        ``<root>-other`` is not mistaken for a child of ``<root>``.
        """
        try:
            common = os.path.commonpath((path, root))
        except ValueError:
            # Raised e.g. for paths on different drives.
            return False
        return os.path.normcase(common) == os.path.normcase(root)

    def _send_file(self, path, mime_type):
        """Send the file at *path* with status 200, or 404 if unreadable."""
        # Read before sending any header so a failure can still become a 404.
        try:
            with open(path, "rb") as content:
                body = content.read()
        except OSError:
            self.send_error(404)
            return
        self.send_response(200)
        self.send_header("Content-type", mime_type)
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _write_event(self, payload):
        """Send *payload* to the client as a single line of JSON."""
        self.wfile.write(bytes(json.dumps(payload) + '\n', "utf-8"))

    def do_POST(self):
        """Generate images for the JSON request body and stream progress.

        The body must be a JSON object with the keys ``prompt``, ``initimg``,
        ``iterations``, ``steps``, ``width``, ``height``, ``cfgscale``,
        ``gfpgan_strength``, ``upscale_level``, ``upscale_strength`` and
        ``seed``. A missing or malformed value yields a 400 response.
        Otherwise a 200 response is started and one JSON object per line is
        written: ``step`` events while sampling, a ``result`` event per
        generated image, and ``upscaling-started`` / ``upscaling-done``
        events when post-processing was requested.
        """
        # Validate everything before committing to a 200 status line.
        try:
            content_length = int(self.headers['Content-Length'])
            post_data = json.loads(self.rfile.read(content_length))
            prompt = post_data['prompt']
            initimg = post_data['initimg']
            iterations = int(post_data['iterations'])
            steps = int(post_data['steps'])
            width = int(post_data['width'])
            height = int(post_data['height'])
            cfgscale = float(post_data['cfgscale'])
            gfpgan_strength = float(post_data['gfpgan_strength'])
            upscale_level = post_data['upscale_level']
            upscale_strength = post_data['upscale_strength']
            upscale = (
                [int(upscale_level), float(upscale_strength)]
                if upscale_level != '' else None
            )
            requested_seed = int(post_data['seed'])
        except (KeyError, TypeError, ValueError) as err:
            self.send_error(400, explain=f"Invalid generation request: {err}")
            return
        # A seed of -1 means "no fixed seed".
        seed = None if requested_seed == -1 else requested_seed

        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()

        print(f"Request to generate with prompt: {prompt}")

        # In order to handle upscaled images, the PngWriter needs to maintain
        # state across images generated by each call to prompt2img(), so we
        # define it in the outer scope of image_done()
        config = post_data.copy()  # Shallow copy
        config['initimg'] = ''  # keep the base64 image out of log and events

        images_generated = 0  # helps keep track of when upscaling is started
        images_upscaled = 0  # helps keep track of when upscaling is completed
        # Controls whether the "postprocessing..." events are emitted at all.
        upscaling_requested = upscale or gfpgan_strength > 0

        pngwriter = PngWriter(
            "./outputs/img-samples/", config['prompt'], 1
        )

        # If upscaling is requested, then this will be called twice: once when
        # the images are first generated, and then again after upscaling is
        # complete. The upscaling replaces the original file, so the second
        # entry should not be inserted into the image list.
        def image_done(image, seed, upscaled=False):
            nonlocal images_generated, images_upscaled

            pngwriter.write_image(image, seed, upscaled)

            # Append post_data to log, but only once!
            if not upscaled:
                current_image = pngwriter.files_written[-1]
                with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
                    log.write(f"{current_image[0]}: {json.dumps(config)}\n")
                self._write_event(
                    {'event': 'result', 'files': current_image, 'config': config}
                )

            if upscaled:
                images_upscaled += 1
            else:
                images_generated += 1

            if not upscaling_requested:
                return
            action = None
            if images_generated >= iterations:
                if images_upscaled < iterations:
                    action = 'upscaling-started'
                else:
                    action = 'upscaling-done'
            if action:
                x = images_upscaled + 1
                self._write_event(
                    {'event': action, 'processed_file_cnt': f'{x}/{iterations}'}
                )

        def image_progress(image, step):
            self._write_event({'event': 'step', 'step': step})

        # Arguments shared by the txt2img and img2img calls.
        common_args = dict(
            iterations=iterations,
            cfg_scale=cfgscale,
            seed=seed,
            steps=steps,
            gfpgan_strength=gfpgan_strength,
            upscale=upscale,
            step_callback=image_progress,
            image_callback=image_done,
        )

        if initimg is None:
            # Run txt2img
            self.model.prompt2image(
                prompt, width=width, height=height, **common_args
            )
        else:
            # Decode initimg as base64 into a per-request temp file; a fixed
            # file name would be shared by concurrently handled requests.
            fd, init_path = tempfile.mkstemp(prefix="img2img-tmp-", suffix=".png")
            try:
                with os.fdopen(fd, "wb") as f:
                    encoded = initimg.split(",")[1]  # Ignore mime type
                    f.write(base64.b64decode(encoded))
                # Run img2img
                self.model.prompt2image(
                    prompt, init_img=init_path, **common_args
                )
            finally:
                # Remove the temp file even when generation fails.
                os.remove(init_path)

        print("Prompt generated!")
class ThreadingDreamServer(ThreadingHTTPServer):
    """ThreadingHTTPServer whose request handler class is DreamServer."""

    def __init__(self, server_address):
        # Callers only pass the address; the handler class is fixed here.
        handler_class = DreamServer
        super().__init__(server_address, handler_class)