moved scripts/dream_server.py into ldm/dream/server.py

This commit is contained in:
Lincoln Stein
2022-08-28 16:37:27 -04:00
parent 08a9702b73
commit 7dfca3dcb5
5 changed files with 594 additions and 128 deletions

View File

@ -9,6 +9,7 @@ import copy
import warnings
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
def main():
"""Initialize command-line parsers and the diffusion model"""
@ -113,17 +114,19 @@ def main():
print('Error loading GFPGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if not infile:
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
)
log_path = os.path.join(opt.outdir, 'dream_log.txt')
cmd_parser = create_cmd_parser()
main_loop(t2i, opt.outdir, cmd_parser, log_path, infile)
with open(log_path, 'a') as log:
cmd_parser = create_cmd_parser()
if opt.web:
dream_server_loop(t2i)
else:
main_loop(t2i, opt.outdir, cmd_parser, log_path, infile)
log.close()
def main_loop(t2i, outdir, parser, log_path, infile):
print(
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
)
"""prompt/read/execute loop"""
done = False
last_seeds = []
@ -246,6 +249,26 @@ def get_next_command(infile=None) -> 'command string':
print(f'#{command}')
return command
def dream_server_loop(t2i):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
)
# Start server
DreamServer.model = t2i
dream_server = ThreadingDreamServer(("0.0.0.0", 9090))
print("\nStarted Stable Diffusion dream server!")
print("Point your browser at http://localhost:9090 or use the host's DNS name or IP address.")
try:
dream_server.serve_forever()
except KeyboardInterrupt:
pass
dream_server.server_close()
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
import torch
@ -426,6 +449,12 @@ def create_argv_parser():
default='../GFPGAN',
help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
)
parser.add_argument(
'--web',
dest='web',
action='store_true',
help='start in web server mode.',
)
return parser

View File

@ -1,114 +0,0 @@
import json
import base64
import mimetypes
import os
from pytorch_lightning import logging
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
print("Loading model...")
from ldm.simplet2i import T2I
model = T2I(sampler_name='k_lms')
# to get rid of annoying warning messages from pytorch
import transformers
transformers.logging.set_verbosity_error()
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
print("Initializing model, be patient...")
model.load_model()
class DreamServer(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
with open("./static/dream_web/index.html", "rb") as content:
self.wfile.write(content.read())
elif os.path.exists("." + self.path):
mime_type = mimetypes.guess_type(self.path)[0]
if mime_type is not None:
self.send_response(200)
self.send_header("Content-type", mime_type)
self.end_headers()
with open("." + self.path, "rb") as content:
self.wfile.write(content.read())
else:
self.send_response(404)
else:
self.send_response(404)
def do_POST(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
content_length = int(self.headers['Content-Length'])
post_data = json.loads(self.rfile.read(content_length))
prompt = post_data['prompt']
initimg = post_data['initimg']
iterations = int(post_data['iterations'])
steps = int(post_data['steps'])
width = int(post_data['width'])
height = int(post_data['height'])
cfgscale = float(post_data['cfgscale'])
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
print(f"Request to generate with prompt: {prompt}")
outputs = []
if initimg is None:
# Run txt2img
outputs = model.txt2img(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
# Run img2img
outputs = model.img2img(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps)
# Remove the temp file
os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}")
post_data['initimg'] = '' # Don't send init image back
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
if __name__ == "__main__":
# Change working directory to the stable-diffusion directory
os.chdir(
os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))
)
# Start server
dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer)
print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *")
try:
dream_server.serve_forever()
except KeyboardInterrupt:
pass
dream_server.server_close()