mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
webui: stream progress events to page
This commit is contained in:
parent
9a8cd9684e
commit
070795a3b4
@ -3,6 +3,7 @@ import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
model = None
|
||||
@ -50,17 +51,42 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
|
||||
print(f"Request to generate with prompt: {prompt}")
|
||||
|
||||
outputs = []
|
||||
def image_done(image, seed):
|
||||
config = post_data.copy() # Shallow copy
|
||||
config['initimg'] = ''
|
||||
|
||||
# Write PNGs
|
||||
pngwriter = PngWriter(
|
||||
"./outputs/img-samples/", config['prompt'], 1
|
||||
)
|
||||
pngwriter.write_image(image, seed)
|
||||
|
||||
# Append post_data to log
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||
for file_path, _ in pngwriter.files_written:
|
||||
log.write(f"{file_path}: {json.dumps(config)}\n")
|
||||
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'result', 'files':pngwriter.files_written, 'config':config}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
def image_progress(image, step):
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'step', 'step':step}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
if initimg is None:
|
||||
# Run txt2img
|
||||
outputs = self.model.txt2img(prompt,
|
||||
iterations=iterations,
|
||||
cfg_scale = cfgscale,
|
||||
width = width,
|
||||
height = height,
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
gfpgan_strength = gfpgan_strength)
|
||||
self.model.prompt2image(prompt,
|
||||
iterations=iterations,
|
||||
cfg_scale = cfgscale,
|
||||
width = width,
|
||||
height = height,
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
@ -68,28 +94,19 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
f.write(base64.b64decode(initimg))
|
||||
|
||||
# Run img2img
|
||||
outputs = self.model.img2img(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
gfpgan_strength=gfpgan_strength,
|
||||
steps = steps)
|
||||
self.model.prompt2image(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
|
||||
print(f"Prompt generated with output: {outputs}")
|
||||
|
||||
post_data['initimg'] = '' # Don't send init image back
|
||||
|
||||
# Append post_data to log
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||
for output in outputs:
|
||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
||||
|
||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
||||
result = {'outputs': outputs}
|
||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
||||
print(f"Prompt generated!")
|
||||
|
||||
|
||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||
|
@ -61,6 +61,9 @@ class KSampler(object):
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
def route_callback(k_callback_values):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||
|
||||
sigmas = self.model.get_sigmas(S)
|
||||
if x_T:
|
||||
@ -78,7 +81,8 @@ class KSampler(object):
|
||||
}
|
||||
return (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
@ -202,6 +202,7 @@ class T2I:
|
||||
ddim_eta=None,
|
||||
skip_normalize=False,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
# these are specific to txt2img
|
||||
width=None,
|
||||
height=None,
|
||||
@ -231,9 +232,14 @@ class T2I:
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
|
||||
To use the callback, define a function or method that receives two arguments, an Image object
|
||||
To use the step callback, define a function that receives two arguments:
|
||||
- Image GPU data
|
||||
- The step number
|
||||
|
||||
To use the image callback, define a function or method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
@ -293,6 +299,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
init_img=init_img,
|
||||
strength=strength,
|
||||
callback=step_callback,
|
||||
)
|
||||
else:
|
||||
images_iterator = self._txt2img(
|
||||
@ -305,6 +312,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
width=width,
|
||||
height=height,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
with scope(self.device.type), self.model.ema_scope():
|
||||
@ -389,6 +397,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
width,
|
||||
height,
|
||||
callback,
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt.
|
||||
@ -412,6 +421,7 @@ class T2I:
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
img_callback=callback
|
||||
)
|
||||
yield self._samples_to_images(samples)
|
||||
|
||||
@ -427,6 +437,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
init_img,
|
||||
strength,
|
||||
callback, # Currently not implemented for img2img
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt and the initial image
|
||||
|
@ -58,8 +58,9 @@
|
||||
</fieldset>
|
||||
</form>
|
||||
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
||||
<br>
|
||||
<progress id="progress" value="0" max="1"></progress>
|
||||
</div>
|
||||
<hr>
|
||||
<div id="results">
|
||||
<div id="no-results-message">
|
||||
<i><p>No results...</p></i>
|
||||
|
@ -7,12 +7,11 @@ function toBase64(file) {
|
||||
});
|
||||
}
|
||||
|
||||
function appendOutput(output) {
|
||||
function appendOutput(src, seed, config) {
|
||||
let outputNode = document.createElement("img");
|
||||
outputNode.src = output[0];
|
||||
outputNode.src = src;
|
||||
|
||||
let outputConfig = output[2];
|
||||
let altText = output[1].toString() + " | " + outputConfig.prompt;
|
||||
let altText = seed.toString() + " | " + config.prompt;
|
||||
outputNode.alt = altText;
|
||||
outputNode.title = altText;
|
||||
|
||||
@ -20,9 +19,9 @@ function appendOutput(output) {
|
||||
outputNode.addEventListener('click', () => {
|
||||
let form = document.querySelector("#generate-form");
|
||||
for (const [k, v] of new FormData(form)) {
|
||||
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
|
||||
form.querySelector(`*[name=${k}]`).value = config[k];
|
||||
}
|
||||
document.querySelector("#seed").value = output[1];
|
||||
document.querySelector("#seed").value = seed;
|
||||
|
||||
saveFields(document.querySelector("#generate-form"));
|
||||
});
|
||||
@ -30,12 +29,6 @@ function appendOutput(output) {
|
||||
document.querySelector("#results").prepend(outputNode);
|
||||
}
|
||||
|
||||
function appendOutputs(outputs) {
|
||||
for (const output of outputs) {
|
||||
appendOutput(output);
|
||||
}
|
||||
}
|
||||
|
||||
function saveFields(form) {
|
||||
for (const [k, v] of new FormData(form)) {
|
||||
if (typeof v !== 'object') { // Don't save 'file' type
|
||||
@ -59,21 +52,43 @@ async function generateSubmit(form) {
|
||||
let formData = Object.fromEntries(new FormData(form));
|
||||
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||
|
||||
// Post as JSON
|
||||
document.querySelector('progress').setAttribute('max', formData.steps);
|
||||
|
||||
// Post as JSON, using Fetch streaming to get results
|
||||
fetch(form.action, {
|
||||
method: form.method,
|
||||
body: JSON.stringify(formData),
|
||||
}).then(async (result) => {
|
||||
let data = await result.json();
|
||||
}).then(async (response) => {
|
||||
const reader = response.body.getReader();
|
||||
|
||||
let noOutputs = true;
|
||||
while (true) {
|
||||
let {value, done} = await reader.read();
|
||||
value = new TextDecoder().decode(value);
|
||||
if (done) break;
|
||||
|
||||
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||
const data = JSON.parse(event);
|
||||
|
||||
if (data.event == 'result') {
|
||||
noOutputs = false;
|
||||
document.querySelector("#no-results-message")?.remove();
|
||||
|
||||
for (let [file, seed] of data.files) {
|
||||
appendOutput(file, seed, data.config);
|
||||
}
|
||||
} else if (data.event == 'step') {
|
||||
document.querySelector('progress').setAttribute('value', data.step.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable form, remove no-results-message
|
||||
form.querySelector('fieldset').removeAttribute('disabled');
|
||||
document.querySelector("#prompt").value = prompt;
|
||||
document.querySelector('progress').setAttribute('value', '0');
|
||||
|
||||
if (data.outputs.length != 0) {
|
||||
document.querySelector("#no-results-message")?.remove();
|
||||
appendOutputs(data.outputs);
|
||||
} else {
|
||||
if (noOutputs) {
|
||||
alert("Error occurred while generating.");
|
||||
}
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user