webui: stream progress events to page

This commit is contained in:
tesseractcat 2022-08-26 21:10:13 -04:00 committed by Kevin Gibbons
parent 9a8cd9684e
commit 070795a3b4
5 changed files with 98 additions and 50 deletions

View File

@ -3,6 +3,7 @@ import base64
import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
class DreamServer(BaseHTTPRequestHandler):
model = None
@ -50,17 +51,42 @@ class DreamServer(BaseHTTPRequestHandler):
print(f"Request to generate with prompt: {prompt}")
outputs = []
def image_done(image, seed):
config = post_data.copy() # Shallow copy
config['initimg'] = ''
# Write PNGs
pngwriter = PngWriter(
"./outputs/img-samples/", config['prompt'], 1
)
pngwriter.write_image(image, seed)
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
for file_path, _ in pngwriter.files_written:
log.write(f"{file_path}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
{'event':'result', 'files':pngwriter.files_written, 'config':config}
) + '\n',"utf-8"))
def image_progress(image, step):
self.wfile.write(bytes(json.dumps(
{'event':'step', 'step':step}
) + '\n',"utf-8"))
if initimg is None:
# Run txt2img
outputs = self.model.txt2img(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
gfpgan_strength = gfpgan_strength)
self.model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
height = height,
seed = seed,
steps = steps,
step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
@ -68,28 +94,19 @@ class DreamServer(BaseHTTPRequestHandler):
f.write(base64.b64decode(initimg))
# Run img2img
outputs = self.model.img2img(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
gfpgan_strength=gfpgan_strength,
steps = steps)
self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
step_callback=image_progress,
image_callback=image_done)
# Remove the temp file
os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}")
post_data['initimg'] = '' # Don't send init image back
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
print(f"Prompt generated!")
class ThreadingDreamServer(ThreadingHTTPServer):

View File

@ -61,6 +61,9 @@ class KSampler(object):
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
sigmas = self.model.get_sigmas(S)
if x_T:
@ -78,7 +81,8 @@ class KSampler(object):
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)

View File

@ -202,6 +202,7 @@ class T2I:
ddim_eta=None,
skip_normalize=False,
image_callback=None,
step_callback=None,
# these are specific to txt2img
width=None,
height=None,
@ -231,9 +232,14 @@ class T2I:
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object
To use the step callback, define a function that receives two arguments:
- Image GPU data
- The step number
To use the image callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example:
@ -293,6 +299,7 @@ class T2I:
skip_normalize=skip_normalize,
init_img=init_img,
strength=strength,
callback=step_callback,
)
else:
images_iterator = self._txt2img(
@ -305,6 +312,7 @@ class T2I:
skip_normalize=skip_normalize,
width=width,
height=height,
callback=step_callback,
)
with scope(self.device.type), self.model.ema_scope():
@ -389,6 +397,7 @@ class T2I:
skip_normalize,
width,
height,
callback,
):
"""
An infinite iterator of images from the prompt.
@ -412,6 +421,7 @@ class T2I:
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
img_callback=callback
)
yield self._samples_to_images(samples)
@ -427,6 +437,7 @@ class T2I:
skip_normalize,
init_img,
strength,
callback, # Currently not implemented for img2img
):
"""
An infinite iterator of images from the prompt and the initial image

View File

@ -58,8 +58,9 @@
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<progress id="progress" value="0" max="1"></progress>
</div>
<hr>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>

View File

@ -7,12 +7,11 @@ function toBase64(file) {
});
}
function appendOutput(output) {
function appendOutput(src, seed, config) {
let outputNode = document.createElement("img");
outputNode.src = output[0];
outputNode.src = src;
let outputConfig = output[2];
let altText = output[1].toString() + " | " + outputConfig.prompt;
let altText = seed.toString() + " | " + config.prompt;
outputNode.alt = altText;
outputNode.title = altText;
@ -20,9 +19,9 @@ function appendOutput(output) {
outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = output[1];
document.querySelector("#seed").value = seed;
saveFields(document.querySelector("#generate-form"));
});
@ -30,12 +29,6 @@ function appendOutput(output) {
document.querySelector("#results").prepend(outputNode);
}
function appendOutputs(outputs) {
for (const output of outputs) {
appendOutput(output);
}
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
@ -59,21 +52,43 @@ async function generateSubmit(form) {
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Post as JSON
document.querySelector('progress').setAttribute('max', formData.steps);
// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (result) => {
let data = await result.json();
}).then(async (response) => {
const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) break;
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event == 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();
for (let [file, seed] of data.files) {
appendOutput(file, seed, data.config);
}
} else if (data.event == 'step') {
document.querySelector('progress').setAttribute('value', data.step.toString());
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (data.outputs.length != 0) {
document.querySelector("#no-results-message")?.remove();
appendOutputs(data.outputs);
} else {
if (noOutputs) {
alert("Error occurred while generating.");
}
});