correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing

This commit is contained in:
Lincoln Stein 2022-08-29 12:08:18 -04:00
commit 4acfb76be6
7 changed files with 136 additions and 53 deletions

View File

@ -68,15 +68,12 @@ class PngWriter:
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
if self.batch_size > 1 or os.path.exists(
os.path.join(self.outdir, filename)
):
path = os.path.join(self.outdir, filename)
if self.batch_size > 1 or os.path.exists(path):
if upscaled:
break
filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists(
os.path.join(self.outdir, filename)
)
finished = not os.path.exists(path)
return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self, image, prompt, path):

View File

@ -3,6 +3,7 @@ import base64
import mimetypes
import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
class DreamServer(BaseHTTPRequestHandler):
model = None
@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler):
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
print(f"Request to generate with prompt: {prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
config = post_data.copy() # Shallow copy
config['initimg'] = ''
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
pngwriter = PngWriter(
"./outputs/img-samples/", config['prompt'], 1
)
# if upscaling is requested, then this will be called twice, once when
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
pngwriter.write_image(image, seed, upscaled)
# Append post_data to log, but only once!
if not upscaled:
current_image = pngwriter.files_written[-1]
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
{'event':'result', 'files':current_image, 'config':config}
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
upscaling_requested = upscale or gfpgan_strength>0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled:
images_upscaled += 1
else:
images_generated +=1
if upscaling_requested:
action = None
if images_generated >= iterations:
if images_upscaled < iterations:
action = 'upscaling-started'
else:
action = 'upscaling-done'
if action:
x = images_upscaled+1
self.wfile.write(bytes(json.dumps(
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
) + '\n',"utf-8"))
def image_progress(image, step):
self.wfile.write(bytes(json.dumps(
{'event':'step', 'step':step}
) + '\n',"utf-8"))
outputs = []
if initimg is None:
# Run txt2img
outputs = self.model.txt2img(prompt,
self.model.prompt2image(prompt,
iterations=iterations,
cfg_scale = cfgscale,
width = width,
@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler):
seed = seed,
steps = steps,
gfpgan_strength = gfpgan_strength,
upscale = upscale
)
upscale = upscale,
step_callback=image_progress,
image_callback=image_done)
else:
# Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f:
@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler):
f.write(base64.b64decode(initimg))
# Run img2img
outputs = self.model.img2img(prompt,
self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png",
iterations = iterations,
cfg_scale = cfgscale,
seed = seed,
steps = steps,
gfpgan_strength=gfpgan_strength,
upscale = upscale,
steps = steps
)
step_callback=image_progress,
image_callback=image_done)
# Remove the temp file
os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}")
post_data['initimg'] = '' # Don't send init image back
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
print(f"Prompt generated!")
class ThreadingDreamServer(ThreadingHTTPServer):

View File

@ -61,6 +61,9 @@ class KSampler(object):
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs,
):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
sigmas = self.model.get_sigmas(S)
if x_T:
@ -78,7 +81,8 @@ class KSampler(object):
}
return (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)

View File

@ -201,6 +201,7 @@ class T2I:
ddim_eta=None,
skip_normalize=False,
image_callback=None,
step_callback=None,
# these are specific to txt2img
width=None,
height=None,
@ -230,9 +231,14 @@ class T2I:
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object
To use the step callback, define a function that receives two arguments:
- Image GPU data
- The step number
To use the image callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example:
@ -292,6 +298,7 @@ class T2I:
skip_normalize=skip_normalize,
init_img=init_img,
strength=strength,
callback=step_callback,
)
else:
images_iterator = self._txt2img(
@ -304,6 +311,7 @@ class T2I:
skip_normalize=skip_normalize,
width=width,
height=height,
callback=step_callback,
)
with scope(self.device.type), self.model.ema_scope():
@ -390,6 +398,7 @@ class T2I:
skip_normalize,
width,
height,
callback,
):
"""
An infinite iterator of images from the prompt.
@ -413,6 +422,7 @@ class T2I:
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
img_callback=callback
)
yield self._samples_to_images(samples)
@ -428,6 +438,7 @@ class T2I:
skip_normalize,
init_img,
strength,
callback, # Currently not implemented for img2img
):
"""
An infinite iterator of images from the prompt and the initial image

View File

@ -18,6 +18,11 @@ fieldset {
#fieldset-search {
display: flex;
}
#scaling-inprocess-message{
font-weight: bold;
font-style: italic;
display: none;
}
#prompt {
flex-grow: 1;

View File

@ -79,8 +79,12 @@
</fieldset>
</form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<progress id="progress" value="0" max="1"></progress>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
</div>
</div>
<hr>
<div id="results">
<div id="no-results-message">
<i><p>No results...</p></i>

View File

@ -7,12 +7,11 @@ function toBase64(file) {
});
}
function appendOutput(output) {
function appendOutput(src, seed, config) {
let outputNode = document.createElement("img");
outputNode.src = output[0];
outputNode.src = src;
let outputConfig = output[2];
let altText = output[1].toString() + " | " + outputConfig.prompt;
let altText = seed.toString() + " | " + config.prompt;
outputNode.alt = altText;
outputNode.title = altText;
@ -20,9 +19,9 @@ function appendOutput(output) {
outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
form.querySelector(`*[name=${k}]`).value = config[k];
}
document.querySelector("#seed").value = output[1];
document.querySelector("#seed").value = seed;
saveFields(document.querySelector("#generate-form"));
});
@ -30,12 +29,6 @@ function appendOutput(output) {
document.querySelector("#results").prepend(outputNode);
}
function appendOutputs(outputs) {
for (const output of outputs) {
appendOutput(output);
}
}
function saveFields(form) {
for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type
@ -65,21 +58,45 @@ async function generateSubmit(form) {
let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Post as JSON
document.querySelector('progress').setAttribute('max', formData.steps);
// Post as JSON, using Fetch streaming to get results
fetch(form.action, {
method: form.method,
body: JSON.stringify(formData),
}).then(async (result) => {
let data = await result.json();
}).then(async (response) => {
const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) break;
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event == 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();
appendOutput(data.files[0],data.files[1],data.config)
} else if (data.event == 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event == 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event == 'step') {
document.querySelector('progress').setAttribute('value', data.step.toString());
}
}
}
// Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (data.outputs.length != 0) {
document.querySelector("#no-results-message")?.remove();
appendOutputs(data.outputs);
} else {
if (noOutputs) {
alert("Error occurred while generating.");
}
});