mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
166 lines
7.4 KiB
Python
166 lines
7.4 KiB
Python
import json
|
|
import base64
|
|
import mimetypes
|
|
import os
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
from ldm.dream.pngwriter import PngWriter
|
|
|
|
class DreamServer(BaseHTTPRequestHandler):
    """HTTP request handler for the dream web UI.

    GET serves the UI page, a generated ``/config.js`` and static files found
    under the current working directory. POST runs one image-generation
    request and streams newline-delimited JSON events back to the client.
    """

    # Object exposing ``prompt2image()``. It is None here, so presumably it is
    # assigned by the code that starts the server -- confirm.
    model = None

    def _send_json_line(self, payload):
        """Write ``payload`` to the client as one line of newline-delimited JSON."""
        self.wfile.write(bytes(json.dumps(payload) + '\n', "utf-8"))

    def _send_content(self, content_type, data):
        """Send a complete 200 response with MIME type ``content_type`` and body ``data`` (bytes)."""
        self.send_response(200)
        self.send_header("Content-type", content_type)
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    @staticmethod
    def _resolve_static_path(request_path, root=None):
        """Map an HTTP request path to a real file path under ``root``.

        Args:
            request_path: the raw request target, e.g. ``/outputs/a.png?v=1``.
            root: directory files may be served from; defaults to the current
                working directory.

        Returns:
            The resolved (symlink-free) path of an existing regular file inside
            ``root``, or None when the target is missing, is not a regular
            file, or resolves outside ``root``.
        """
        from urllib.parse import unquote

        root = os.path.realpath(os.getcwd() if root is None else root)
        # Drop the query string / fragment and decode %-escapes, as browsers
        # send them; the path itself is then treated as relative to root.
        url_path = request_path.split("?", 1)[0].split("#", 1)[0]
        relative = unquote(url_path).lstrip("/")
        candidate = os.path.realpath(os.path.join(root, relative))
        try:
            # commonpath compares whole path components; a character-wise
            # commonprefix would let root "/app" accept "/app2/secret".
            inside_root = os.path.commonpath((candidate, root)) == root
        except ValueError:
            # e.g. paths on different drives (Windows): not inside root.
            return None
        if not (inside_root and os.path.isfile(candidate)):
            return None
        return candidate

    def do_GET(self):
        """Serve the UI page, the generated ``/config.js``, or a static file."""
        route = self.path.split("?", 1)[0]
        if route == "/":
            with open("./static/dream_web/index.html", "rb") as content:
                self._send_content("text/html", content.read())
        elif route == "/config.js":
            # unfortunately this import can't be at the top level, since that would cause a circular import
            from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists

            config = {
                'gfpgan_model_exists': gfpgan_model_exists
            }
            self._send_content(
                "application/javascript",
                bytes("let config = " + json.dumps(config) + ";\n", "utf-8"),
            )
        else:
            path = self._resolve_static_path(self.path)
            mime_type = mimetypes.guess_type(path)[0] if path else None
            if mime_type is None:
                # send_error emits a complete response; a bare send_response(404)
                # without end_headers() never reached the client.
                self.send_error(404)
                return
            with open(path, "rb") as content:
                self._send_content(mime_type, content.read())

    def do_POST(self):
        """Run one generation request and stream its progress as JSON lines.

        The request body is a JSON object with the keys read below. The
        response is declared ``application/json`` but is written as
        newline-delimited events:
        - ``step`` for each progress callback;
        - ``result`` for each newly written image;
        - ``upscaling-started`` / ``upscaling-done`` around post-processing.
        """
        # Headers are sent up front so events can be streamed during generation.
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()

        # unfortunately this import can't be at the top level, since that would cause a circular import
        from ldm.gfpgan.gfpgan_tools import gfpgan_model_exists

        content_length = int(self.headers['Content-Length'])
        post_data = json.loads(self.rfile.read(content_length))
        prompt = post_data['prompt']
        initimg = post_data['initimg']  # None selects txt2img; otherwise a base64 payload
        iterations = int(post_data['iterations'])
        steps = int(post_data['steps'])
        width = int(post_data['width'])
        height = int(post_data['height'])
        cfgscale = float(post_data['cfgscale'])
        sampler_name = post_data['sampler']
        gfpgan_strength = float(post_data['gfpgan_strength']) if gfpgan_model_exists else 0
        upscale_level = post_data['upscale_level']
        upscale_strength = post_data['upscale_strength']
        # An empty upscale_level disables upscaling.
        upscale = [int(upscale_level), float(upscale_strength)] if upscale_level != '' else None
        requested_seed = int(post_data['seed'])
        # -1 is mapped to None (presumably "let the model pick a seed" -- confirm).
        seed = None if requested_seed == -1 else requested_seed

        print(f"Request to generate with prompt: {prompt}")

        # The request echoed into the log and result events, minus the init image.
        config = post_data.copy()  # Shallow copy
        config['initimg'] = ''

        # In order to handle upscaled images, the PngWriter needs to maintain state
        # across images generated by each call to prompt2img(), so we define it in
        # the outer scope of image_done()
        pngwriter = PngWriter(
            "./outputs/img-samples/", config['prompt'], 1
        )

        images_generated = 0  # helps keep track of when upscaling is started
        images_upscaled = 0   # helps keep track of when upscaling is completed
        # Controls the "postprocessing..." message; fixed for the whole request.
        upscaling_requested = upscale or gfpgan_strength > 0

        # if upscaling is requested, then this will be called twice, once when
        # the images are first generated, and then again when after upscaling
        # is complete. The upscaling replaces the original file, so the second
        # entry should not be inserted into the image list.
        def image_done(image, seed, upscaled=False):
            nonlocal images_generated, images_upscaled

            pngwriter.write_image(image, seed, upscaled)

            # Append post_data to log, but only once!
            if not upscaled:
                current_image = pngwriter.files_written[-1]
                with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
                    log.write(f"{current_image[0]}: {json.dumps(config)}\n")
                self._send_json_line(
                    {'event': 'result', 'files': current_image, 'config': config}
                )

            if upscaled:
                images_upscaled += 1
            else:
                images_generated += 1

            # Post-processing messages only start once every image is generated.
            if not upscaling_requested or images_generated < iterations:
                return
            if images_upscaled < iterations:
                action = 'upscaling-started'
                processed = images_upscaled + 1  # the image now being processed
            else:
                action = 'upscaling-done'
                processed = iterations  # previously reported iterations+1
            self._send_json_line(
                {'event': action, 'processed_file_cnt': f'{processed}/{iterations}'}
            )

        def image_progress(image, step):
            self._send_json_line({'event': 'step', 'step': step})

        if initimg is None:
            # Run txt2img
            self.model.prompt2image(prompt,
                                    iterations=iterations,
                                    cfg_scale = cfgscale,
                                    width = width,
                                    height = height,
                                    seed = seed,
                                    steps = steps,
                                    gfpgan_strength = gfpgan_strength,
                                    upscale = upscale,
                                    sampler_name = sampler_name,
                                    step_callback=image_progress,
                                    image_callback=image_done)
        else:
            import tempfile  # local import, like the gfpgan import above

            # Keep only the text after the first comma, presumably the base64
            # part of a data URL; a string with no comma is decoded whole
            # instead of raising IndexError.
            payload = initimg.split(",", 1)[-1]  # Ignore mime type
            # A per-request temp file: a fixed shared name would be clobbered by
            # concurrent requests handled by the threading server.
            fd, init_img_path = tempfile.mkstemp(prefix="img2img-", suffix=".png", dir=".")
            try:
                with os.fdopen(fd, "wb") as f:
                    f.write(base64.b64decode(payload))

                # Run img2img
                self.model.prompt2image(prompt,
                                        init_img = init_img_path,
                                        iterations = iterations,
                                        cfg_scale = cfgscale,
                                        seed = seed,
                                        steps = steps,
                                        sampler_name = sampler_name,
                                        gfpgan_strength=gfpgan_strength,
                                        upscale = upscale,
                                        step_callback=image_progress,
                                        image_callback=image_done)
            finally:
                # Remove the temp file even when decoding or generation fails.
                os.remove(init_img_path)

        print("Prompt generated!")
|
|
|
|
|
|
class ThreadingDreamServer(ThreadingHTTPServer):
    """Thread-per-request HTTP server whose requests are handled by DreamServer."""

    def __init__(self, server_address):
        # Every incoming connection is handled by a fresh DreamServer instance.
        handler_class = DreamServer
        super().__init__(server_address, handler_class)
|