mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
webui: support cancelation
This commit is contained in:
parent
dae2b26765
commit
5d13207aa6
@ -4,9 +4,14 @@ import mimetypes
|
|||||||
import os
|
import os
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from ldm.dream.pngwriter import PngWriter
|
from ldm.dream.pngwriter import PngWriter
|
||||||
|
from threading import Event
|
||||||
|
|
||||||
|
class CanceledException(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
class DreamServer(BaseHTTPRequestHandler):
|
class DreamServer(BaseHTTPRequestHandler):
|
||||||
model = None
|
model = None
|
||||||
|
canceled = Event()
|
||||||
|
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path == "/":
|
if self.path == "/":
|
||||||
@ -25,6 +30,12 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
'gfpgan_model_exists': gfpgan_model_exists
|
'gfpgan_model_exists': gfpgan_model_exists
|
||||||
}
|
}
|
||||||
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
self.wfile.write(bytes("let config = " + json.dumps(config) + ";\n", "utf-8"))
|
||||||
|
elif self.path == "/cancel":
|
||||||
|
self.canceled.set()
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_header("Content-type", "application/json")
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write(bytes('{}', 'utf8'))
|
||||||
else:
|
else:
|
||||||
path = "." + self.path
|
path = "." + self.path
|
||||||
cwd = os.path.realpath(os.getcwd())
|
cwd = os.path.realpath(os.getcwd())
|
||||||
@ -67,6 +78,7 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
progress_images = 'progress_images' in post_data
|
progress_images = 'progress_images' in post_data
|
||||||
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
|
seed = self.model.seed if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||||
|
|
||||||
|
self.canceled.clear()
|
||||||
print(f"Request to generate with prompt: {prompt}")
|
print(f"Request to generate with prompt: {prompt}")
|
||||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||||
# across images generated by each call to prompt2img(), so we define it in
|
# across images generated by each call to prompt2img(), so we define it in
|
||||||
@ -121,6 +133,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
|
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
|
||||||
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
|
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
|
||||||
def image_progress(sample, step):
|
def image_progress(sample, step):
|
||||||
|
if self.canceled.is_set():
|
||||||
|
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
|
||||||
|
raise CanceledException
|
||||||
url = None
|
url = None
|
||||||
# since rendering images is moderately expensive, only render every 5th image
|
# since rendering images is moderately expensive, only render every 5th image
|
||||||
# and don't bother with the last one, since it'll render anyway
|
# and don't bother with the last one, since it'll render anyway
|
||||||
@ -133,41 +148,46 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
{'event':'step', 'step':step, 'url': url}
|
{'event':'step', 'step':step, 'url': url}
|
||||||
) + '\n',"utf-8"))
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
if initimg is None:
|
try:
|
||||||
# Run txt2img
|
if initimg is None:
|
||||||
self.model.prompt2image(prompt,
|
# Run txt2img
|
||||||
iterations=iterations,
|
self.model.prompt2image(prompt,
|
||||||
cfg_scale = cfgscale,
|
iterations=iterations,
|
||||||
width = width,
|
cfg_scale = cfgscale,
|
||||||
height = height,
|
width = width,
|
||||||
seed = seed,
|
height = height,
|
||||||
steps = steps,
|
seed = seed,
|
||||||
gfpgan_strength = gfpgan_strength,
|
steps = steps,
|
||||||
upscale = upscale,
|
gfpgan_strength = gfpgan_strength,
|
||||||
sampler_name = sampler_name,
|
upscale = upscale,
|
||||||
step_callback=image_progress,
|
sampler_name = sampler_name,
|
||||||
image_callback=image_done)
|
step_callback=image_progress,
|
||||||
else:
|
image_callback=image_done)
|
||||||
# Decode initimg as base64 to temp file
|
else:
|
||||||
with open("./img2img-tmp.png", "wb") as f:
|
# Decode initimg as base64 to temp file
|
||||||
initimg = initimg.split(",")[1] # Ignore mime type
|
with open("./img2img-tmp.png", "wb") as f:
|
||||||
f.write(base64.b64decode(initimg))
|
initimg = initimg.split(",")[1] # Ignore mime type
|
||||||
|
f.write(base64.b64decode(initimg))
|
||||||
|
|
||||||
# Run img2img
|
try:
|
||||||
self.model.prompt2image(prompt,
|
# Run img2img
|
||||||
init_img = "./img2img-tmp.png",
|
self.model.prompt2image(prompt,
|
||||||
iterations = iterations,
|
init_img = "./img2img-tmp.png",
|
||||||
cfg_scale = cfgscale,
|
iterations = iterations,
|
||||||
seed = seed,
|
cfg_scale = cfgscale,
|
||||||
steps = steps,
|
seed = seed,
|
||||||
sampler_name = sampler_name,
|
steps = steps,
|
||||||
gfpgan_strength=gfpgan_strength,
|
sampler_name = sampler_name,
|
||||||
upscale = upscale,
|
gfpgan_strength=gfpgan_strength,
|
||||||
step_callback=image_progress,
|
upscale = upscale,
|
||||||
image_callback=image_done)
|
step_callback=image_progress,
|
||||||
|
image_callback=image_done)
|
||||||
# Remove the temp file
|
finally:
|
||||||
os.remove("./img2img-tmp.png")
|
# Remove the temp file
|
||||||
|
os.remove("./img2img-tmp.png")
|
||||||
|
except CanceledException:
|
||||||
|
print(f"Canceled.")
|
||||||
|
return
|
||||||
|
|
||||||
print(f"Prompt generated!")
|
print(f"Prompt generated!")
|
||||||
|
|
||||||
|
@ -74,3 +74,7 @@ label {
|
|||||||
width: 30vh;
|
width: 30vh;
|
||||||
height: 30vh;
|
height: 30vh;
|
||||||
}
|
}
|
||||||
|
#cancel-button {
|
||||||
|
cursor: pointer;
|
||||||
|
color: red;
|
||||||
|
}
|
||||||
|
@ -87,6 +87,7 @@
|
|||||||
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
||||||
<div id="progress-section">
|
<div id="progress-section">
|
||||||
<progress id="progress-bar" value="0" max="1"></progress>
|
<progress id="progress-bar" value="0" max="1"></progress>
|
||||||
|
<span id="cancel-button" title="Cancel">✖</span>
|
||||||
<br>
|
<br>
|
||||||
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'></img>
|
<img id="progress-image" src='data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg"/>'></img>
|
||||||
<div id="scaling-inprocess-message">
|
<div id="scaling-inprocess-message">
|
||||||
|
@ -89,23 +89,26 @@ async function generateSubmit(form) {
|
|||||||
for (let event of value.split('\n').filter(e => e !== '')) {
|
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||||
const data = JSON.parse(event);
|
const data = JSON.parse(event);
|
||||||
|
|
||||||
if (data.event == 'result') {
|
if (data.event === 'result') {
|
||||||
noOutputs = false;
|
noOutputs = false;
|
||||||
document.querySelector("#no-results-message")?.remove();
|
document.querySelector("#no-results-message")?.remove();
|
||||||
appendOutput(data.files[0],data.files[1],data.config);
|
appendOutput(data.files[0],data.files[1],data.config);
|
||||||
progressEle.setAttribute('value', 0);
|
progressEle.setAttribute('value', 0);
|
||||||
progressEle.setAttribute('max', formData.steps);
|
progressEle.setAttribute('max', formData.steps);
|
||||||
progressImageEle.src = BLANK_IMAGE_URL;
|
progressImageEle.src = BLANK_IMAGE_URL;
|
||||||
} else if (data.event == 'upscaling-started') {
|
} else if (data.event === 'upscaling-started') {
|
||||||
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
|
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
|
||||||
document.getElementById("scaling-inprocess-message").style.display = "block";
|
document.getElementById("scaling-inprocess-message").style.display = "block";
|
||||||
} else if (data.event == 'upscaling-done') {
|
} else if (data.event === 'upscaling-done') {
|
||||||
document.getElementById("scaling-inprocess-message").style.display = "none";
|
document.getElementById("scaling-inprocess-message").style.display = "none";
|
||||||
} else if (data.event == 'step') {
|
} else if (data.event === 'step') {
|
||||||
progressEle.setAttribute('value', data.step);
|
progressEle.setAttribute('value', data.step);
|
||||||
if (data.url) {
|
if (data.url) {
|
||||||
progressImageEle.src = data.url;
|
progressImageEle.src = data.url;
|
||||||
}
|
}
|
||||||
|
} else if (data.event === 'canceled') {
|
||||||
|
// avoid alerting as if this were an error case
|
||||||
|
noOutputs = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -144,6 +147,12 @@ window.onload = () => {
|
|||||||
});
|
});
|
||||||
loadFields(document.querySelector("#generate-form"));
|
loadFields(document.querySelector("#generate-form"));
|
||||||
|
|
||||||
|
document.querySelector('#cancel-button').addEventListener('click', () => {
|
||||||
|
fetch('/cancel').catch(e => {
|
||||||
|
console.error(e);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
if (!config.gfpgan_model_exists) {
|
if (!config.gfpgan_model_exists) {
|
||||||
document.querySelector("#gfpgan").style.display = 'none';
|
document.querySelector("#gfpgan").style.display = 'none';
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user