From 5d13207aa6da1f198215be0ee67616915db57b3e Mon Sep 17 00:00:00 2001 From: Kevin Gibbons Date: Tue, 30 Aug 2022 08:55:40 -0700 Subject: [PATCH] webui: support cancelation --- ldm/dream/server.py | 88 +++++++++++++++++++++++-------------- static/dream_web/index.css | 4 ++ static/dream_web/index.html | 1 + static/dream_web/index.js | 17 +++++-- 4 files changed, 72 insertions(+), 38 deletions(-) diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 3f1126a1ac..a13665cedc 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -4,9 +4,14 @@ import mimetypes import os from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from ldm.dream.pngwriter import PngWriter +from threading import Event + +class CanceledException(Exception): + pass class DreamServer(BaseHTTPRequestHandler): model = None + canceled = Event() def do_GET(self): if self.path == "/": @@ -25,6 +30,12 @@ class DreamServer(BaseHTTPRequestHandler): 'gfpgan_model_exists': gfpgan_model_exists } self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8")) + elif self.path == "/cancel": + self.canceled.set() + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(bytes('{}', 'utf8')) else: path = "." + self.path cwd = os.path.realpath(os.getcwd()) @@ -67,6 +78,7 @@ class DreamServer(BaseHTTPRequestHandler): progress_images = 'progress_images' in post_data seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed']) + self.canceled.clear() print(f"Request to generate with prompt: {prompt}") # In order to handle upscaled images, the PngWriter needs to maintain state # across images generated by each call to prompt2img(), so we define it in @@ -121,6 +133,9 @@ class DreamServer(BaseHTTPRequestHandler): # it doesn't need to know if batch_size > 1, just if this is _part of a batch_ step_writer = PngWriter('./outputs/intermediates/', prompt, 2) def image_progress(sample, step): + if self.canceled.is_set(): + self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8')) + raise CanceledException url = None # since rendering images is moderately expensive, only render every 5th image # and don't bother with the last one, since it'll render anyway @@ -133,41 +148,46 @@ class DreamServer(BaseHTTPRequestHandler): {'event':'step', 'step':step, 'url': url} ) + '\n',"utf-8")) - if initimg is None: - # Run txt2img - self.model.prompt2image(prompt, - iterations=iterations, - cfg_scale = cfgscale, - width = width, - height = height, - seed = seed, - steps = steps, - gfpgan_strength = gfpgan_strength, - upscale = upscale, - sampler_name = sampler_name, - step_callback=image_progress, - image_callback=image_done) - else: - # Decode initimg as base64 to temp file - with open("./img2img-tmp.png", "wb") as f: - initimg = initimg.split(",")[1] # Ignore mime type - f.write(base64.b64decode(initimg)) + try: + if initimg is None: + # Run txt2img + self.model.prompt2image(prompt, + iterations=iterations, + cfg_scale = cfgscale, + width = width, + height = height, + seed = seed, + steps = steps, + gfpgan_strength = gfpgan_strength, + upscale = upscale, + sampler_name = sampler_name, + step_callback=image_progress, + image_callback=image_done) + else: + # Decode initimg as base64 to temp file + with open("./img2img-tmp.png", "wb") as f: + initimg = initimg.split(",")[1] # Ignore mime type + f.write(base64.b64decode(initimg)) - # Run img2img - self.model.prompt2image(prompt, - init_img = "./img2img-tmp.png", - iterations = iterations, - cfg_scale = cfgscale, - seed = seed, - steps = steps, - sampler_name = sampler_name, - gfpgan_strength=gfpgan_strength, - upscale = upscale, - step_callback=image_progress, - image_callback=image_done) - - # Remove the temp file - os.remove("./img2img-tmp.png") + try: + # Run img2img + self.model.prompt2image(prompt, + init_img = "./img2img-tmp.png", + iterations = iterations, + cfg_scale = cfgscale, + seed = seed, + steps = steps, + sampler_name = sampler_name, + gfpgan_strength=gfpgan_strength, + upscale = upscale, + step_callback=image_progress, + image_callback=image_done) + finally: + # Remove the temp file + os.remove("./img2img-tmp.png") + except CanceledException: + print(f"Canceled.") + return print(f"Prompt generated!") diff --git a/static/dream_web/index.css b/static/dream_web/index.css index beb20e5a74..00dc1c5c8d 100644 --- a/static/dream_web/index.css +++ b/static/dream_web/index.css @@ -74,3 +74,7 @@ label { width: 30vh; height: 30vh; } +#cancel-button { + cursor: pointer; + color: red; +} diff --git a/static/dream_web/index.html b/static/dream_web/index.html index 5838b303da..fa233f07d9 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -87,6 +87,7 @@
For news and support for this web service, visit our GitHub site
+
diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 37899cf41b..32175d0f22 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -89,23 +89,26 @@ async function generateSubmit(form) { for (let event of value.split('\n').filter(e => e !== '')) { const data = JSON.parse(event); - if (data.event == 'result') { + if (data.event === 'result') { noOutputs = false; document.querySelector("#no-results-message")?.remove(); appendOutput(data.files[0],data.files[1],data.config); progressEle.setAttribute('value', 0); progressEle.setAttribute('max', formData.steps); progressImageEle.src = BLANK_IMAGE_URL; - } else if (data.event == 'upscaling-started') { + } else if (data.event === 'upscaling-started') { document.getElementById("processing_cnt").textContent=data.processed_file_cnt; document.getElementById("scaling-inprocess-message").style.display = "block"; - } else if (data.event == 'upscaling-done') { + } else if (data.event === 'upscaling-done') { document.getElementById("scaling-inprocess-message").style.display = "none"; - } else if (data.event == 'step') { + } else if (data.event === 'step') { progressEle.setAttribute('value', data.step); if (data.url) { progressImageEle.src = data.url; } + } else if (data.event === 'canceled') { + // avoid alerting as if this were an error case + noOutputs = false; } } } @@ -144,6 +147,12 @@ window.onload = () => { }); loadFields(document.querySelector("#generate-form")); + document.querySelector('#cancel-button').addEventListener('click', () => { + fetch('/cancel').catch(e => { + console.error(e); + }); + }); + if (!config.gfpgan_model_exists) { document.querySelector("#gfpgan").style.display = 'none'; }