correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing

This commit is contained in:
Lincoln Stein 2022-08-29 12:08:18 -04:00
commit 4acfb76be6
7 changed files with 136 additions and 53 deletions

View File

@ -68,15 +68,12 @@ class PngWriter:
while not finished: while not finished:
series += 1 series += 1
filename = f'{basecount:06}.{seed}.png' filename = f'{basecount:06}.{seed}.png'
if self.batch_size > 1 or os.path.exists( path = os.path.join(self.outdir, filename)
os.path.join(self.outdir, filename) if self.batch_size > 1 or os.path.exists(path):
):
if upscaled: if upscaled:
break break
filename = f'{basecount:06}.{seed}.{series:02}.png' filename = f'{basecount:06}.{seed}.{series:02}.png'
finished = not os.path.exists( finished = not os.path.exists(path)
os.path.join(self.outdir, filename)
)
return os.path.join(self.outdir, filename) return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self, image, prompt, path): def save_image_and_prompt_to_png(self, image, prompt, path):

View File

@ -3,6 +3,7 @@ import base64
import mimetypes import mimetypes
import os import os
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
class DreamServer(BaseHTTPRequestHandler): class DreamServer(BaseHTTPRequestHandler):
model = None model = None
@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler):
seed = None if int(post_data['seed']) == -1 else int(post_data['seed']) seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
print(f"Request to generate with prompt: {prompt}") print(f"Request to generate with prompt: {prompt}")
# In order to handle upscaled images, the PngWriter needs to maintain state
# across images generated by each call to prompt2img(), so we define it in
# the outer scope of image_done()
config = post_data.copy() # Shallow copy
config['initimg'] = ''
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
pngwriter = PngWriter(
"./outputs/img-samples/", config['prompt'], 1
)
# if upscaling is requested, then this will be called twice, once when
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
pngwriter.write_image(image, seed, upscaled)
# Append post_data to log, but only once!
if not upscaled:
current_image = pngwriter.files_written[-1]
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
self.wfile.write(bytes(json.dumps(
{'event':'result', 'files':current_image, 'config':config}
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
upscaling_requested = upscale or gfpgan_strength>0
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
if upscaled:
images_upscaled += 1
else:
images_generated +=1
if upscaling_requested:
action = None
if images_generated >= iterations:
if images_upscaled < iterations:
action = 'upscaling-started'
else:
action = 'upscaling-done'
if action:
x = images_upscaled+1
self.wfile.write(bytes(json.dumps(
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
) + '\n',"utf-8"))
def image_progress(image, step):
self.wfile.write(bytes(json.dumps(
{'event':'step', 'step':step}
) + '\n',"utf-8"))
outputs = []
if initimg is None: if initimg is None:
# Run txt2img # Run txt2img
outputs = self.model.txt2img(prompt, self.model.prompt2image(prompt,
iterations=iterations, iterations=iterations,
cfg_scale = cfgscale, cfg_scale = cfgscale,
width = width, width = width,
@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler):
seed = seed, seed = seed,
steps = steps, steps = steps,
gfpgan_strength = gfpgan_strength, gfpgan_strength = gfpgan_strength,
upscale = upscale upscale = upscale,
) step_callback=image_progress,
image_callback=image_done)
else: else:
# Decode initimg as base64 to temp file # Decode initimg as base64 to temp file
with open("./img2img-tmp.png", "wb") as f: with open("./img2img-tmp.png", "wb") as f:
@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler):
f.write(base64.b64decode(initimg)) f.write(base64.b64decode(initimg))
# Run img2img # Run img2img
outputs = self.model.img2img(prompt, self.model.prompt2image(prompt,
init_img = "./img2img-tmp.png", init_img = "./img2img-tmp.png",
iterations = iterations, iterations = iterations,
cfg_scale = cfgscale, cfg_scale = cfgscale,
seed = seed, seed = seed,
steps = steps,
gfpgan_strength=gfpgan_strength, gfpgan_strength=gfpgan_strength,
upscale = upscale, upscale = upscale,
steps = steps step_callback=image_progress,
) image_callback=image_done)
# Remove the temp file # Remove the temp file
os.remove("./img2img-tmp.png") os.remove("./img2img-tmp.png")
print(f"Prompt generated with output: {outputs}") print(f"Prompt generated!")
post_data['initimg'] = '' # Don't send init image back
# Append post_data to log
with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log:
for output in outputs:
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
outputs = [x + [post_data] for x in outputs] # Append config to each output
result = {'outputs': outputs}
self.wfile.write(bytes(json.dumps(result), "utf-8"))
class ThreadingDreamServer(ThreadingHTTPServer): class ThreadingDreamServer(ThreadingHTTPServer):

View File

@ -61,6 +61,9 @@ class KSampler(object):
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs, **kwargs,
): ):
def route_callback(k_callback_values):
if img_callback is not None:
img_callback(k_callback_values['x'], k_callback_values['i'])
sigmas = self.model.get_sigmas(S) sigmas = self.model.get_sigmas(S)
if x_T: if x_T:
@ -78,7 +81,8 @@ class KSampler(object):
} }
return ( return (
K.sampling.__dict__[f'sample_{self.schedule}']( K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
), ),
None, None,
) )

View File

@ -201,6 +201,7 @@ class T2I:
ddim_eta=None, ddim_eta=None,
skip_normalize=False, skip_normalize=False,
image_callback=None, image_callback=None,
step_callback=None,
# these are specific to txt2img # these are specific to txt2img
width=None, width=None,
height=None, height=None,
@ -230,9 +231,14 @@ class T2I:
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated image_callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object To use the step callback, define a function that receives two arguments:
- Image GPU data
- The step number
To use the image callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to and the seed. You can then do whatever you like with the image, including converting it to
different formats and manipulating it. For example: different formats and manipulating it. For example:
@ -292,6 +298,7 @@ class T2I:
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
init_img=init_img, init_img=init_img,
strength=strength, strength=strength,
callback=step_callback,
) )
else: else:
images_iterator = self._txt2img( images_iterator = self._txt2img(
@ -304,6 +311,7 @@ class T2I:
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
width=width, width=width,
height=height, height=height,
callback=step_callback,
) )
with scope(self.device.type), self.model.ema_scope(): with scope(self.device.type), self.model.ema_scope():
@ -390,6 +398,7 @@ class T2I:
skip_normalize, skip_normalize,
width, width,
height, height,
callback,
): ):
""" """
An infinite iterator of images from the prompt. An infinite iterator of images from the prompt.
@ -413,6 +422,7 @@ class T2I:
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=ddim_eta, eta=ddim_eta,
img_callback=callback
) )
yield self._samples_to_images(samples) yield self._samples_to_images(samples)
@ -428,6 +438,7 @@ class T2I:
skip_normalize, skip_normalize,
init_img, init_img,
strength, strength,
callback, # Currently not implemented for img2img
): ):
""" """
An infinite iterator of images from the prompt and the initial image An infinite iterator of images from the prompt and the initial image

View File

@ -18,6 +18,11 @@ fieldset {
#fieldset-search { #fieldset-search {
display: flex; display: flex;
} }
#scaling-inprocess-message{
font-weight: bold;
font-style: italic;
display: none;
}
#prompt { #prompt {
flex-grow: 1; flex-grow: 1;

View File

@ -79,8 +79,12 @@
</fieldset> </fieldset>
</form> </form>
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div> <div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
<br>
<progress id="progress" value="0" max="1"></progress>
<div id="scaling-inprocess-message">
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
</div>
</div> </div>
<hr>
<div id="results"> <div id="results">
<div id="no-results-message"> <div id="no-results-message">
<i><p>No results...</p></i> <i><p>No results...</p></i>

View File

@ -7,12 +7,11 @@ function toBase64(file) {
}); });
} }
function appendOutput(output) { function appendOutput(src, seed, config) {
let outputNode = document.createElement("img"); let outputNode = document.createElement("img");
outputNode.src = output[0]; outputNode.src = src;
let outputConfig = output[2]; let altText = seed.toString() + " | " + config.prompt;
let altText = output[1].toString() + " | " + outputConfig.prompt;
outputNode.alt = altText; outputNode.alt = altText;
outputNode.title = altText; outputNode.title = altText;
@ -20,9 +19,9 @@ function appendOutput(output) {
outputNode.addEventListener('click', () => { outputNode.addEventListener('click', () => {
let form = document.querySelector("#generate-form"); let form = document.querySelector("#generate-form");
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
form.querySelector(`*[name=${k}]`).value = outputConfig[k]; form.querySelector(`*[name=${k}]`).value = config[k];
} }
document.querySelector("#seed").value = output[1]; document.querySelector("#seed").value = seed;
saveFields(document.querySelector("#generate-form")); saveFields(document.querySelector("#generate-form"));
}); });
@ -30,12 +29,6 @@ function appendOutput(output) {
document.querySelector("#results").prepend(outputNode); document.querySelector("#results").prepend(outputNode);
} }
function appendOutputs(outputs) {
for (const output of outputs) {
appendOutput(output);
}
}
function saveFields(form) { function saveFields(form) {
for (const [k, v] of new FormData(form)) { for (const [k, v] of new FormData(form)) {
if (typeof v !== 'object') { // Don't save 'file' type if (typeof v !== 'object') { // Don't save 'file' type
@ -65,21 +58,45 @@ async function generateSubmit(form) {
let formData = Object.fromEntries(new FormData(form)); let formData = Object.fromEntries(new FormData(form));
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
// Post as JSON document.querySelector('progress').setAttribute('max', formData.steps);
// Post as JSON, using Fetch streaming to get results
fetch(form.action, { fetch(form.action, {
method: form.method, method: form.method,
body: JSON.stringify(formData), body: JSON.stringify(formData),
}).then(async (result) => { }).then(async (response) => {
let data = await result.json(); const reader = response.body.getReader();
let noOutputs = true;
while (true) {
let {value, done} = await reader.read();
value = new TextDecoder().decode(value);
if (done) break;
for (let event of value.split('\n').filter(e => e !== '')) {
const data = JSON.parse(event);
if (data.event == 'result') {
noOutputs = false;
document.querySelector("#no-results-message")?.remove();
appendOutput(data.files[0],data.files[1],data.config)
} else if (data.event == 'upscaling-started') {
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
document.getElementById("scaling-inprocess-message").style.display = "block";
} else if (data.event == 'upscaling-done') {
document.getElementById("scaling-inprocess-message").style.display = "none";
} else if (data.event == 'step') {
document.querySelector('progress').setAttribute('value', data.step.toString());
}
}
}
// Re-enable form, remove no-results-message // Re-enable form, remove no-results-message
form.querySelector('fieldset').removeAttribute('disabled'); form.querySelector('fieldset').removeAttribute('disabled');
document.querySelector("#prompt").value = prompt; document.querySelector("#prompt").value = prompt;
document.querySelector('progress').setAttribute('value', '0');
if (data.outputs.length != 0) { if (noOutputs) {
document.querySelector("#no-results-message")?.remove();
appendOutputs(data.outputs);
} else {
alert("Error occurred while generating."); alert("Error occurred while generating.");
} }
}); });