mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
moved scripts/dream_server.py into ldm/dream/server.py
This commit is contained in:
@ -9,6 +9,7 @@ import copy
|
||||
import warnings
|
||||
import ldm.dream.readline
|
||||
from ldm.dream.pngwriter import PngWriter, PromptFormatter
|
||||
from ldm.dream.server import DreamServer, ThreadingDreamServer
|
||||
|
||||
def main():
|
||||
"""Initialize command-line parsers and the diffusion model"""
|
||||
@ -113,17 +114,19 @@ def main():
|
||||
print('Error loading GFPGAN:', file=sys.stderr)
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
||||
if not infile:
|
||||
print(
|
||||
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
|
||||
)
|
||||
|
||||
log_path = os.path.join(opt.outdir, 'dream_log.txt')
|
||||
cmd_parser = create_cmd_parser()
|
||||
main_loop(t2i, opt.outdir, cmd_parser, log_path, infile)
|
||||
|
||||
with open(log_path, 'a') as log:
|
||||
cmd_parser = create_cmd_parser()
|
||||
if opt.web:
|
||||
dream_server_loop(t2i)
|
||||
else:
|
||||
main_loop(t2i, opt.outdir, cmd_parser, log_path, infile)
|
||||
log.close()
|
||||
|
||||
def main_loop(t2i, outdir, parser, log_path, infile):
|
||||
print(
|
||||
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit, 'cd' to change output dir, 'pwd' to print output dir)..."
|
||||
)
|
||||
"""prompt/read/execute loop"""
|
||||
done = False
|
||||
last_seeds = []
|
||||
@ -246,6 +249,26 @@ def get_next_command(infile=None) -> 'command string':
|
||||
print(f'#{command}')
|
||||
return command
|
||||
|
||||
def dream_server_loop(t2i):
|
||||
print('\n* --web was specified, starting web server...')
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
|
||||
)
|
||||
|
||||
# Start server
|
||||
DreamServer.model = t2i
|
||||
dream_server = ThreadingDreamServer(("0.0.0.0", 9090))
|
||||
print("\nStarted Stable Diffusion dream server!")
|
||||
print("Point your browser at http://localhost:9090 or use the host's DNS name or IP address.")
|
||||
|
||||
try:
|
||||
dream_server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
dream_server.server_close()
|
||||
|
||||
def load_gfpgan_bg_upsampler(bg_upsampler, bg_tile=400):
|
||||
import torch
|
||||
|
||||
@ -426,6 +449,12 @@ def create_argv_parser():
|
||||
default='../GFPGAN',
|
||||
help='indicates the directory containing the GFPGAN code. Only used if --gfpgan is specified',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--web',
|
||||
dest='web',
|
||||
action='store_true',
|
||||
help='start in web server mode.',
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -1,114 +0,0 @@
|
||||
import json
|
||||
import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from pytorch_lightning import logging
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
|
||||
print("Loading model...")
|
||||
from ldm.simplet2i import T2I
|
||||
model = T2I(sampler_name='k_lms')
|
||||
|
||||
# to get rid of annoying warning messages from pytorch
|
||||
import transformers
|
||||
transformers.logging.set_verbosity_error()
|
||||
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
|
||||
|
||||
print("Initializing model, be patient...")
|
||||
model.load_model()
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == "/":
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
with open("./static/dream_web/index.html", "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
elif os.path.exists("." + self.path):
|
||||
mime_type = mimetypes.guess_type(self.path)[0]
|
||||
if mime_type is not None:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", mime_type)
|
||||
self.end_headers()
|
||||
with open("." + self.path, "rb") as content:
|
||||
self.wfile.write(content.read())
|
||||
else:
|
||||
self.send_response(404)
|
||||
else:
|
||||
self.send_response(404)
|
||||
|
||||
def do_POST(self):
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "application/json")
|
||||
self.end_headers()
|
||||
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
post_data = json.loads(self.rfile.read(content_length))
|
||||
prompt = post_data['prompt']
|
||||
initimg = post_data['initimg']
|
||||
iterations = int(post_data['iterations'])
|
||||
steps = int(post_data['steps'])
|
||||
width = int(post_data['width'])
|
||||
height = int(post_data['height'])
|
||||
cfgscale = float(post_data['cfgscale'])
|
||||
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||
|
||||
print(f"Request to generate with prompt: {prompt}")
|
||||
|
||||
outputs = []
|
||||
if initimg is None:
|
||||
# Run txt2img
|
||||
outputs = model.txt2img(prompt,
|
||||
iterations=iterations,
|
||||
cfg_scale = cfgscale,
|
||||
width = width,
|
||||
height = height,
|
||||
seed = seed,
|
||||
steps = steps)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
initimg = initimg.split(",")[1] # Ignore mime type
|
||||
f.write(base64.b64decode(initimg))
|
||||
|
||||
# Run img2img
|
||||
outputs = model.img2img(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
steps = steps)
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
|
||||
print(f"Prompt generated with output: {outputs}")
|
||||
|
||||
post_data['initimg'] = '' # Don't send init image back
|
||||
|
||||
# Append post_data to log
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||
for output in outputs:
|
||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
||||
|
||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
||||
result = {'outputs': outputs}
|
||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Change working directory to the stable-diffusion directory
|
||||
os.chdir(
|
||||
os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))
|
||||
)
|
||||
|
||||
# Start server
|
||||
dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer)
|
||||
print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *")
|
||||
|
||||
try:
|
||||
dream_server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
dream_server.server_close()
|
||||
|
Reference in New Issue
Block a user