Merge pull request #182 from bakkot/webui-cancel

webui: support cancelation
This commit is contained in:
Lincoln Stein 2022-08-30 12:02:05 -04:00 committed by GitHub
commit 8bf321f6ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 38 deletions

View File

@ -4,9 +4,14 @@ import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
class CanceledException(Exception):
pass
class DreamServer(BaseHTTPRequestHandler):
model = None
canceled = Event()
def do_GET(self):
if self.path == "/":
@ -25,6 +30,12 @@ class DreamServer(BaseHTTPRequestHandler):
'gfpgan_model_exists': gfpgan_model_exists
}
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
elif self.path == "/cancel":
self.canceled.set()
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()
self.wfile.write(bytes('{}', 'utf8'))
else:
path = "." + self.path
cwd = os.path.realpath(os.getcwd())
@ -67,6 +78,7 @@ class DreamServer(BaseHTTPRequestHandler):
progress_images = 'progress_images' in post_data
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
self.canceled.clear()
print(f"Request to generate with prompt: {prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
@ -121,6 +133,9 @@ class DreamServer(BaseHTTPRequestHandler):
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
def image_progress(sample, step):
if self.canceled.is_set():
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
raise CanceledException
url = None
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
@ -133,41 +148,46 @@ class DreamServer(BaseHTTPRequestHandler):
{'event':'step', 'step':step, 'url': url}
) + '\n',"utf-8"))
if initimg is None:
# Run txt2img
self.model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
gfpgan_strength = gfpgan_strength,
upscale = upscale,
sampler_name = sampler_name,
step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
try:
if initimg is None:
# Run txt2img
self.model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
gfpgan_strength = gfpgan_strength,
upscale = upscale,
sampler_name = sampler_name,
step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
initimg = initimg.split(",")[1] # Ignore mime type
f.write(base64.b64decode(initimg))
# Run img2img
self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
sampler_name = sampler_name,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
step_callback=image_progress,
image_callback=image_done)
# Remove the temp file
os.remove("./img2img-tmp.png")
try:
# Run img2img
self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
sampler_name = sampler_name,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
step_callback=image_progress,
image_callback=image_done)
finally:
# Remove the temp file
os.remove("./img2img-tmp.png")
except CanceledException:
print(f"Canceled.")
return
print(f"Prompt generated!")

View File

@ -74,3 +74,7 @@ label {
width: 30vh;
height: 30vh;
}
#cancel-button {
cursor: pointer;
color: red;
}

View File

@ -87,6 +87,7 @@
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<div id="progress-section">
<progress id="progress-bar" value="0" max="1"></progress>
<span id="cancel-button" title="Cancel">&#10006;</span>
<br>
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'></img>
<div id="scaling-inprocess-message">

View File

@ -89,23 +89,26 @@ async function generateSubmit(form) {
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event == 'result') {
if (data.event === 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();
appendOutput(data.files[0],data.files[1],data.config);
progressEle.setAttribute('value', 0);
progressEle.setAttribute('max', formData.steps);
progressImageEle.src = BLANK_IMAGE_URL;
} else if (data.event == 'upscaling-started') {
} else if (data.event === 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event == 'upscaling-done') {
} else if (data.event === 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event == 'step') {
} else if (data.event === 'step') {
progressEle.setAttribute('value', data.step);
if (data.url) {
progressImageEle.src = data.url;
}
} else if (data.event === 'canceled') {
// avoid alerting as if this were an error case
noOutputs = false;
}
}
}
@ -144,6 +147,12 @@ window.onload = () => {
});
loadFields(document.querySelector("#generate-form"));
document.querySelector('#cancel-button').addEventListener('click', () => {
fetch('/cancel').catch(e => {
console.error(e);
});
});
if (!config.gfpgan_model_exists) {
document.querySelector("#gfpgan").style.display = 'none';
}