mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
webui: stream progress events to page
This commit is contained in:
parent
9a8cd9684e
commit
070795a3b4
@ -3,6 +3,7 @@ import base64
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
|
from ldm.dream.pngwriter import PngWriter
|
||||||
|
|
||||||
class DreamServer(BaseHTTPRequestHandler):
|
class DreamServer(BaseHTTPRequestHandler):
|
||||||
model = None
|
model = None
|
||||||
@ -50,17 +51,42 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
print(f"Request to generate with prompt: {prompt}")
|
print(f"Request to generate with prompt: {prompt}")
|
||||||
|
|
||||||
outputs = []
|
def image_done(image, seed):
|
||||||
|
config = post_data.copy() # Shallow copy
|
||||||
|
config['initimg'] = ''
|
||||||
|
|
||||||
|
# Write PNGs
|
||||||
|
pngwriter = PngWriter(
|
||||||
|
"./outputs/img-samples/", config['prompt'], 1
|
||||||
|
)
|
||||||
|
pngwriter.write_image(image, seed)
|
||||||
|
|
||||||
|
# Append post_data to log
|
||||||
|
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||||
|
for file_path, _ in pngwriter.files_written:
|
||||||
|
log.write(f"{file_path}: {json.dumps(config)}\n")
|
||||||
|
|
||||||
|
self.wfile.write(bytes(json.dumps(
|
||||||
|
{'event':'result', 'files':pngwriter.files_written, 'config':config}
|
||||||
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
|
def image_progress(image, step):
|
||||||
|
self.wfile.write(bytes(json.dumps(
|
||||||
|
{'event':'step', 'step':step}
|
||||||
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
if initimg is None:
|
if initimg is None:
|
||||||
# Run txt2img
|
# Run txt2img
|
||||||
outputs = self.model.txt2img(prompt,
|
self.model.prompt2image(prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
cfg_scale = cfgscale,
|
cfg_scale = cfgscale,
|
||||||
width = width,
|
width = width,
|
||||||
height = height,
|
height = height,
|
||||||
seed = seed,
|
seed = seed,
|
||||||
steps = steps,
|
steps = steps,
|
||||||
gfpgan_strength = gfpgan_strength)
|
|
||||||
|
step_callback=image_progress,
|
||||||
|
image_callback=image_done)
|
||||||
else:
|
else:
|
||||||
# Decode initimg as base64 to temp file
|
# Decode initimg as base64 to temp file
|
||||||
with open("./img2img-tmp.png", "wb") as f:
|
with open("./img2img-tmp.png", "wb") as f:
|
||||||
@ -68,28 +94,19 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
f.write(base64.b64decode(initimg))
|
f.write(base64.b64decode(initimg))
|
||||||
|
|
||||||
# Run img2img
|
# Run img2img
|
||||||
outputs = self.model.img2img(prompt,
|
self.model.prompt2image(prompt,
|
||||||
init_img = "./img2img-tmp.png",
|
init_img = "./img2img-tmp.png",
|
||||||
iterations = iterations,
|
iterations = iterations,
|
||||||
cfg_scale = cfgscale,
|
cfg_scale = cfgscale,
|
||||||
seed = seed,
|
seed = seed,
|
||||||
gfpgan_strength=gfpgan_strength,
|
steps = steps,
|
||||||
steps = steps)
|
step_callback=image_progress,
|
||||||
|
image_callback=image_done)
|
||||||
|
|
||||||
# Remove the temp file
|
# Remove the temp file
|
||||||
os.remove("./img2img-tmp.png")
|
os.remove("./img2img-tmp.png")
|
||||||
|
|
||||||
print(f"Prompt generated with output: {outputs}")
|
print(f"Prompt generated!")
|
||||||
|
|
||||||
post_data['initimg'] = '' # Don't send init image back
|
|
||||||
|
|
||||||
# Append post_data to log
|
|
||||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
|
||||||
for output in outputs:
|
|
||||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
|
||||||
|
|
||||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
|
||||||
result = {'outputs': outputs}
|
|
||||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||||
|
@ -61,6 +61,9 @@ class KSampler(object):
|
|||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
def route_callback(k_callback_values):
|
||||||
|
if img_callback is not None:
|
||||||
|
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||||
|
|
||||||
sigmas = self.model.get_sigmas(S)
|
sigmas = self.model.get_sigmas(S)
|
||||||
if x_T:
|
if x_T:
|
||||||
@ -78,7 +81,8 @@ class KSampler(object):
|
|||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
model_wrap_cfg, x, sigmas, extra_args=extra_args
|
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||||
|
callback=route_callback
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
@ -202,6 +202,7 @@ class T2I:
|
|||||||
ddim_eta=None,
|
ddim_eta=None,
|
||||||
skip_normalize=False,
|
skip_normalize=False,
|
||||||
image_callback=None,
|
image_callback=None,
|
||||||
|
step_callback=None,
|
||||||
# these are specific to txt2img
|
# these are specific to txt2img
|
||||||
width=None,
|
width=None,
|
||||||
height=None,
|
height=None,
|
||||||
@ -231,9 +232,14 @@ class T2I:
|
|||||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||||
|
step_callback // a function or method that will be called each step
|
||||||
image_callback // a function or method that will be called each time an image is generated
|
image_callback // a function or method that will be called each time an image is generated
|
||||||
|
|
||||||
To use the callback, define a function of method that receives two arguments, an Image object
|
To use the step callback, define a function that receives two arguments:
|
||||||
|
- Image GPU data
|
||||||
|
- The step number
|
||||||
|
|
||||||
|
To use the image callback, define a function of method that receives two arguments, an Image object
|
||||||
and the seed. You can then do whatever you like with the image, including converting it to
|
and the seed. You can then do whatever you like with the image, including converting it to
|
||||||
different formats and manipulating it. For example:
|
different formats and manipulating it. For example:
|
||||||
|
|
||||||
@ -293,6 +299,7 @@ class T2I:
|
|||||||
skip_normalize=skip_normalize,
|
skip_normalize=skip_normalize,
|
||||||
init_img=init_img,
|
init_img=init_img,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
images_iterator = self._txt2img(
|
images_iterator = self._txt2img(
|
||||||
@ -305,6 +312,7 @@ class T2I:
|
|||||||
skip_normalize=skip_normalize,
|
skip_normalize=skip_normalize,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
with scope(self.device.type), self.model.ema_scope():
|
with scope(self.device.type), self.model.ema_scope():
|
||||||
@ -389,6 +397,7 @@ class T2I:
|
|||||||
skip_normalize,
|
skip_normalize,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
|
callback,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt.
|
An infinite iterator of images from the prompt.
|
||||||
@ -412,6 +421,7 @@ class T2I:
|
|||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta,
|
||||||
|
img_callback=callback
|
||||||
)
|
)
|
||||||
yield self._samples_to_images(samples)
|
yield self._samples_to_images(samples)
|
||||||
|
|
||||||
@ -427,6 +437,7 @@ class T2I:
|
|||||||
skip_normalize,
|
skip_normalize,
|
||||||
init_img,
|
init_img,
|
||||||
strength,
|
strength,
|
||||||
|
callback, # Currently not implemented for img2img
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt and the initial image
|
An infinite iterator of images from the prompt and the initial image
|
||||||
|
@ -58,8 +58,9 @@
|
|||||||
</fieldset>
|
</fieldset>
|
||||||
</form>
|
</form>
|
||||||
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
||||||
|
<br>
|
||||||
|
<progress id="progress" value="0" max="1"></progress>
|
||||||
</div>
|
</div>
|
||||||
<hr>
|
|
||||||
<div id="results">
|
<div id="results">
|
||||||
<div id="no-results-message">
|
<div id="no-results-message">
|
||||||
<i><p>No results...</p></i>
|
<i><p>No results...</p></i>
|
||||||
|
@ -7,12 +7,11 @@ function toBase64(file) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendOutput(output) {
|
function appendOutput(src, seed, config) {
|
||||||
let outputNode = document.createElement("img");
|
let outputNode = document.createElement("img");
|
||||||
outputNode.src = output[0];
|
outputNode.src = src;
|
||||||
|
|
||||||
let outputConfig = output[2];
|
let altText = seed.toString() + " | " + config.prompt;
|
||||||
let altText = output[1].toString() + " | " + outputConfig.prompt;
|
|
||||||
outputNode.alt = altText;
|
outputNode.alt = altText;
|
||||||
outputNode.title = altText;
|
outputNode.title = altText;
|
||||||
|
|
||||||
@ -20,9 +19,9 @@ function appendOutput(output) {
|
|||||||
outputNode.addEventListener('click', () => {
|
outputNode.addEventListener('click', () => {
|
||||||
let form = document.querySelector("#generate-form");
|
let form = document.querySelector("#generate-form");
|
||||||
for (const [k, v] of new FormData(form)) {
|
for (const [k, v] of new FormData(form)) {
|
||||||
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
|
form.querySelector(`*[name=${k}]`).value = config[k];
|
||||||
}
|
}
|
||||||
document.querySelector("#seed").value = output[1];
|
document.querySelector("#seed").value = seed;
|
||||||
|
|
||||||
saveFields(document.querySelector("#generate-form"));
|
saveFields(document.querySelector("#generate-form"));
|
||||||
});
|
});
|
||||||
@ -30,12 +29,6 @@ function appendOutput(output) {
|
|||||||
document.querySelector("#results").prepend(outputNode);
|
document.querySelector("#results").prepend(outputNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendOutputs(outputs) {
|
|
||||||
for (const output of outputs) {
|
|
||||||
appendOutput(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function saveFields(form) {
|
function saveFields(form) {
|
||||||
for (const [k, v] of new FormData(form)) {
|
for (const [k, v] of new FormData(form)) {
|
||||||
if (typeof v !== 'object') { // Don't save 'file' type
|
if (typeof v !== 'object') { // Don't save 'file' type
|
||||||
@ -59,21 +52,43 @@ async function generateSubmit(form) {
|
|||||||
let formData = Object.fromEntries(new FormData(form));
|
let formData = Object.fromEntries(new FormData(form));
|
||||||
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||||
|
|
||||||
// Post as JSON
|
document.querySelector('progress').setAttribute('max', formData.steps);
|
||||||
|
|
||||||
|
// Post as JSON, using Fetch streaming to get results
|
||||||
fetch(form.action, {
|
fetch(form.action, {
|
||||||
method: form.method,
|
method: form.method,
|
||||||
body: JSON.stringify(formData),
|
body: JSON.stringify(formData),
|
||||||
}).then(async (result) => {
|
}).then(async (response) => {
|
||||||
let data = await result.json();
|
const reader = response.body.getReader();
|
||||||
|
|
||||||
|
let noOutputs = true;
|
||||||
|
while (true) {
|
||||||
|
let {value, done} = await reader.read();
|
||||||
|
value = new TextDecoder().decode(value);
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||||
|
const data = JSON.parse(event);
|
||||||
|
|
||||||
|
if (data.event == 'result') {
|
||||||
|
noOutputs = false;
|
||||||
|
document.querySelector("#no-results-message")?.remove();
|
||||||
|
|
||||||
|
for (let [file, seed] of data.files) {
|
||||||
|
appendOutput(file, seed, data.config);
|
||||||
|
}
|
||||||
|
} else if (data.event == 'step') {
|
||||||
|
document.querySelector('progress').setAttribute('value', data.step.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Re-enable form, remove no-results-message
|
// Re-enable form, remove no-results-message
|
||||||
form.querySelector('fieldset').removeAttribute('disabled');
|
form.querySelector('fieldset').removeAttribute('disabled');
|
||||||
document.querySelector("#prompt").value = prompt;
|
document.querySelector("#prompt").value = prompt;
|
||||||
|
document.querySelector('progress').setAttribute('value', '0');
|
||||||
|
|
||||||
if (data.outputs.length != 0) {
|
if (noOutputs) {
|
||||||
document.querySelector("#no-results-message")?.remove();
|
|
||||||
appendOutputs(data.outputs);
|
|
||||||
} else {
|
|
||||||
alert("Error occurred while generating.");
|
alert("Error occurred while generating.");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user