import json import base64 import mimetypes import os from pytorch_lightning import logging from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer print("Loading model...") from ldm.simplet2i import T2I model = T2I(sampler_name='k_lms') # to get rid of annoying warning messages from pytorch import transformers transformers.logging.set_verbosity_error() logging.getLogger("pytorch_lightning").setLevel(logging.ERROR) print("Initializing model, be patient...") model.load_model() class DreamServer(BaseHTTPRequestHandler): def do_GET(self): if self.path == "/": self.send_response(200) self.send_header("Content-type", "text/html") self.end_headers() with open("./static/dream_web/index.html", "rb") as content: self.wfile.write(content.read()) elif os.path.exists("." + self.path): mime_type = mimetypes.guess_type(self.path)[0] if mime_type is not None: self.send_response(200) self.send_header("Content-type", mime_type) self.end_headers() with open("." + self.path, "rb") as content: self.wfile.write(content.read()) else: self.send_response(404) else: self.send_response(404) def do_POST(self): self.send_response(200) self.send_header("Content-type", "application/json") self.end_headers() content_length = int(self.headers['Content-Length']) post_data = json.loads(self.rfile.read(content_length)) prompt = post_data['prompt'] initimg = post_data['initimg'] iterations = int(post_data['iterations']) steps = int(post_data['steps']) width = int(post_data['width']) height = int(post_data['height']) cfgscale = float(post_data['cfgscale']) seed = None if int(post_data['seed']) == -1 else int(post_data['seed']) print(f"Request to generate with prompt: {prompt}") outputs = [] if initimg is None: # Run txt2img outputs = model.txt2img(prompt, iterations=iterations, cfg_scale = cfgscale, width = width, height = height, seed = seed, steps = steps) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: initimg = initimg.split(",")[1] # Ignore mime type f.write(base64.b64decode(initimg)) # Run img2img outputs = model.img2img(prompt, init_img = "./img2img-tmp.png", iterations = iterations, cfg_scale = cfgscale, seed = seed, steps = steps) # Remove the temp file os.remove("./img2img-tmp.png") print(f"Prompt generated with output: {outputs}") post_data['initimg'] = '' # Don't send init image back # Append post_data to log with open("./outputs/img-samples/dream_web_log.txt", "a") as log: for output in outputs: log.write(f"{output[0]}: {json.dumps(post_data)}\n") outputs = [x + [post_data] for x in outputs] # Append config to each output result = {'outputs': outputs} self.wfile.write(bytes(json.dumps(result), "utf-8")) if __name__ == "__main__": # Change working directory to the stable-diffusion directory os.chdir( os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')) ) # Start server dream_server = ThreadingHTTPServer(("0.0.0.0", 9090), DreamServer) print("\n\n* Started Stable Diffusion dream server! Point your browser at http://localhost:9090 or use the host's DNS name or IP address. *") try: dream_server.serve_forever() except KeyboardInterrupt: pass dream_server.server_close()