mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing
This commit is contained in:
commit
4acfb76be6
@ -68,15 +68,12 @@ class PngWriter:
|
|||||||
while not finished:
|
while not finished:
|
||||||
series += 1
|
series += 1
|
||||||
filename = f'{basecount:06}.{seed}.png'
|
filename = f'{basecount:06}.{seed}.png'
|
||||||
if self.batch_size > 1 or os.path.exists(
|
path = os.path.join(self.outdir, filename)
|
||||||
os.path.join(self.outdir, filename)
|
if self.batch_size > 1 or os.path.exists(path):
|
||||||
):
|
|
||||||
if upscaled:
|
if upscaled:
|
||||||
break
|
break
|
||||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||||
finished = not os.path.exists(
|
finished = not os.path.exists(path)
|
||||||
os.path.join(self.outdir, filename)
|
|
||||||
)
|
|
||||||
return os.path.join(self.outdir, filename)
|
return os.path.join(self.outdir, filename)
|
||||||
|
|
||||||
def save_image_and_prompt_to_png(self, image, prompt, path):
|
def save_image_and_prompt_to_png(self, image, prompt, path):
|
||||||
|
@ -3,6 +3,7 @@ import base64
|
|||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
|
from ldm.dream.pngwriter import PngWriter
|
||||||
|
|
||||||
class DreamServer(BaseHTTPRequestHandler):
|
class DreamServer(BaseHTTPRequestHandler):
|
||||||
model = None
|
model = None
|
||||||
@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
|
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||||
|
|
||||||
print(f"Request to generate with prompt: {prompt}")
|
print(f"Request to generate with prompt: {prompt}")
|
||||||
|
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||||
|
# across images generated by each call to prompt2img(), so we define it in
|
||||||
|
# the outer scope of image_done()
|
||||||
|
config = post_data.copy() # Shallow copy
|
||||||
|
config['initimg'] = ''
|
||||||
|
|
||||||
|
images_generated = 0 # helps keep track of when upscaling is started
|
||||||
|
images_upscaled = 0 # helps keep track of when upscaling is completed
|
||||||
|
pngwriter = PngWriter(
|
||||||
|
"./outputs/img-samples/", config['prompt'], 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# if upscaling is requested, then this will be called twice, once when
|
||||||
|
# the images are first generated, and then again after upscaling
|
||||||
|
# is complete. The upscaling replaces the original file, so the second
|
||||||
|
# entry should not be inserted into the image list.
|
||||||
|
def image_done(image, seed, upscaled=False):
|
||||||
|
pngwriter.write_image(image, seed, upscaled)
|
||||||
|
|
||||||
|
# Append post_data to log, but only once!
|
||||||
|
if not upscaled:
|
||||||
|
current_image = pngwriter.files_written[-1]
|
||||||
|
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||||
|
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
|
||||||
|
self.wfile.write(bytes(json.dumps(
|
||||||
|
{'event':'result', 'files':current_image, 'config':config}
|
||||||
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
|
# control state of the "postprocessing..." message
|
||||||
|
upscaling_requested = upscale or gfpgan_strength>0
|
||||||
|
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||||
|
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||||
|
if upscaled:
|
||||||
|
images_upscaled += 1
|
||||||
|
else:
|
||||||
|
images_generated +=1
|
||||||
|
if upscaling_requested:
|
||||||
|
action = None
|
||||||
|
if images_generated >= iterations:
|
||||||
|
if images_upscaled < iterations:
|
||||||
|
action = 'upscaling-started'
|
||||||
|
else:
|
||||||
|
action = 'upscaling-done'
|
||||||
|
if action:
|
||||||
|
x = images_upscaled+1
|
||||||
|
self.wfile.write(bytes(json.dumps(
|
||||||
|
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
|
||||||
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
|
def image_progress(image, step):
|
||||||
|
self.wfile.write(bytes(json.dumps(
|
||||||
|
{'event':'step', 'step':step}
|
||||||
|
) + '\n',"utf-8"))
|
||||||
|
|
||||||
outputs = []
|
|
||||||
if initimg is None:
|
if initimg is None:
|
||||||
# Run txt2img
|
# Run txt2img
|
||||||
outputs = self.model.txt2img(prompt,
|
self.model.prompt2image(prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
cfg_scale = cfgscale,
|
cfg_scale = cfgscale,
|
||||||
width = width,
|
width = width,
|
||||||
@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
seed = seed,
|
seed = seed,
|
||||||
steps = steps,
|
steps = steps,
|
||||||
gfpgan_strength = gfpgan_strength,
|
gfpgan_strength = gfpgan_strength,
|
||||||
upscale = upscale
|
upscale = upscale,
|
||||||
)
|
step_callback=image_progress,
|
||||||
|
image_callback=image_done)
|
||||||
else:
|
else:
|
||||||
# Decode initimg as base64 to temp file
|
# Decode initimg as base64 to temp file
|
||||||
with open("./img2img-tmp.png", "wb") as f:
|
with open("./img2img-tmp.png", "wb") as f:
|
||||||
@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler):
|
|||||||
f.write(base64.b64decode(initimg))
|
f.write(base64.b64decode(initimg))
|
||||||
|
|
||||||
# Run img2img
|
# Run img2img
|
||||||
outputs = self.model.img2img(prompt,
|
self.model.prompt2image(prompt,
|
||||||
init_img = "./img2img-tmp.png",
|
init_img = "./img2img-tmp.png",
|
||||||
iterations = iterations,
|
iterations = iterations,
|
||||||
cfg_scale = cfgscale,
|
cfg_scale = cfgscale,
|
||||||
seed = seed,
|
seed = seed,
|
||||||
|
steps = steps,
|
||||||
gfpgan_strength=gfpgan_strength,
|
gfpgan_strength=gfpgan_strength,
|
||||||
upscale = upscale,
|
upscale = upscale,
|
||||||
steps = steps
|
step_callback=image_progress,
|
||||||
)
|
image_callback=image_done)
|
||||||
|
|
||||||
# Remove the temp file
|
# Remove the temp file
|
||||||
os.remove("./img2img-tmp.png")
|
os.remove("./img2img-tmp.png")
|
||||||
|
|
||||||
print(f"Prompt generated with output: {outputs}")
|
print(f"Prompt generated!")
|
||||||
|
|
||||||
post_data['initimg'] = '' # Don't send init image back
|
|
||||||
|
|
||||||
# Append post_data to log
|
|
||||||
with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log:
|
|
||||||
for output in outputs:
|
|
||||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
|
||||||
|
|
||||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
|
||||||
result = {'outputs': outputs}
|
|
||||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||||
|
@ -61,6 +61,9 @@ class KSampler(object):
|
|||||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
def route_callback(k_callback_values):
|
||||||
|
if img_callback is not None:
|
||||||
|
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||||
|
|
||||||
sigmas = self.model.get_sigmas(S)
|
sigmas = self.model.get_sigmas(S)
|
||||||
if x_T:
|
if x_T:
|
||||||
@ -78,7 +81,8 @@ class KSampler(object):
|
|||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||||
model_wrap_cfg, x, sigmas, extra_args=extra_args
|
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||||
|
callback=route_callback
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
@ -201,6 +201,7 @@ class T2I:
|
|||||||
ddim_eta=None,
|
ddim_eta=None,
|
||||||
skip_normalize=False,
|
skip_normalize=False,
|
||||||
image_callback=None,
|
image_callback=None,
|
||||||
|
step_callback=None,
|
||||||
# these are specific to txt2img
|
# these are specific to txt2img
|
||||||
width=None,
|
width=None,
|
||||||
height=None,
|
height=None,
|
||||||
@ -230,9 +231,14 @@ class T2I:
|
|||||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||||
|
step_callback // a function or method that will be called each step
|
||||||
image_callback // a function or method that will be called each time an image is generated
|
image_callback // a function or method that will be called each time an image is generated
|
||||||
|
|
||||||
To use the callback, define a function of method that receives two arguments, an Image object
|
To use the step callback, define a function that receives two arguments:
|
||||||
|
- Image GPU data
|
||||||
|
- The step number
|
||||||
|
|
||||||
|
To use the image callback, define a function or method that receives two arguments, an Image object
|
||||||
and the seed. You can then do whatever you like with the image, including converting it to
|
and the seed. You can then do whatever you like with the image, including converting it to
|
||||||
different formats and manipulating it. For example:
|
different formats and manipulating it. For example:
|
||||||
|
|
||||||
@ -292,6 +298,7 @@ class T2I:
|
|||||||
skip_normalize=skip_normalize,
|
skip_normalize=skip_normalize,
|
||||||
init_img=init_img,
|
init_img=init_img,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
images_iterator = self._txt2img(
|
images_iterator = self._txt2img(
|
||||||
@ -304,6 +311,7 @@ class T2I:
|
|||||||
skip_normalize=skip_normalize,
|
skip_normalize=skip_normalize,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
|
callback=step_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
with scope(self.device.type), self.model.ema_scope():
|
with scope(self.device.type), self.model.ema_scope():
|
||||||
@ -390,6 +398,7 @@ class T2I:
|
|||||||
skip_normalize,
|
skip_normalize,
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
|
callback,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt.
|
An infinite iterator of images from the prompt.
|
||||||
@ -413,6 +422,7 @@ class T2I:
|
|||||||
unconditional_guidance_scale=cfg_scale,
|
unconditional_guidance_scale=cfg_scale,
|
||||||
unconditional_conditioning=uc,
|
unconditional_conditioning=uc,
|
||||||
eta=ddim_eta,
|
eta=ddim_eta,
|
||||||
|
img_callback=callback
|
||||||
)
|
)
|
||||||
yield self._samples_to_images(samples)
|
yield self._samples_to_images(samples)
|
||||||
|
|
||||||
@ -428,6 +438,7 @@ class T2I:
|
|||||||
skip_normalize,
|
skip_normalize,
|
||||||
init_img,
|
init_img,
|
||||||
strength,
|
strength,
|
||||||
|
callback, # Currently not implemented for img2img
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
An infinite iterator of images from the prompt and the initial image
|
An infinite iterator of images from the prompt and the initial image
|
||||||
|
@ -18,6 +18,11 @@ fieldset {
|
|||||||
#fieldset-search {
|
#fieldset-search {
|
||||||
display: flex;
|
display: flex;
|
||||||
}
|
}
|
||||||
|
#scaling-inprocess-message{
|
||||||
|
font-weight: bold;
|
||||||
|
font-style: italic;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
#prompt {
|
#prompt {
|
||||||
flex-grow: 1;
|
flex-grow: 1;
|
||||||
|
|
||||||
|
@ -79,8 +79,12 @@
|
|||||||
</fieldset>
|
</fieldset>
|
||||||
</form>
|
</form>
|
||||||
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
||||||
|
<br>
|
||||||
|
<progress id="progress" value="0" max="1"></progress>
|
||||||
|
<div id="scaling-inprocess-message">
|
||||||
|
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<hr>
|
|
||||||
<div id="results">
|
<div id="results">
|
||||||
<div id="no-results-message">
|
<div id="no-results-message">
|
||||||
<i><p>No results...</p></i>
|
<i><p>No results...</p></i>
|
||||||
|
@ -7,12 +7,11 @@ function toBase64(file) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendOutput(output) {
|
function appendOutput(src, seed, config) {
|
||||||
let outputNode = document.createElement("img");
|
let outputNode = document.createElement("img");
|
||||||
outputNode.src = output[0];
|
outputNode.src = src;
|
||||||
|
|
||||||
let outputConfig = output[2];
|
let altText = seed.toString() + " | " + config.prompt;
|
||||||
let altText = output[1].toString() + " | " + outputConfig.prompt;
|
|
||||||
outputNode.alt = altText;
|
outputNode.alt = altText;
|
||||||
outputNode.title = altText;
|
outputNode.title = altText;
|
||||||
|
|
||||||
@ -20,9 +19,9 @@ function appendOutput(output) {
|
|||||||
outputNode.addEventListener('click', () => {
|
outputNode.addEventListener('click', () => {
|
||||||
let form = document.querySelector("#generate-form");
|
let form = document.querySelector("#generate-form");
|
||||||
for (const [k, v] of new FormData(form)) {
|
for (const [k, v] of new FormData(form)) {
|
||||||
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
|
form.querySelector(`*[name=${k}]`).value = config[k];
|
||||||
}
|
}
|
||||||
document.querySelector("#seed").value = output[1];
|
document.querySelector("#seed").value = seed;
|
||||||
|
|
||||||
saveFields(document.querySelector("#generate-form"));
|
saveFields(document.querySelector("#generate-form"));
|
||||||
});
|
});
|
||||||
@ -30,12 +29,6 @@ function appendOutput(output) {
|
|||||||
document.querySelector("#results").prepend(outputNode);
|
document.querySelector("#results").prepend(outputNode);
|
||||||
}
|
}
|
||||||
|
|
||||||
function appendOutputs(outputs) {
|
|
||||||
for (const output of outputs) {
|
|
||||||
appendOutput(output);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function saveFields(form) {
|
function saveFields(form) {
|
||||||
for (const [k, v] of new FormData(form)) {
|
for (const [k, v] of new FormData(form)) {
|
||||||
if (typeof v !== 'object') { // Don't save 'file' type
|
if (typeof v !== 'object') { // Don't save 'file' type
|
||||||
@ -65,21 +58,45 @@ async function generateSubmit(form) {
|
|||||||
let formData = Object.fromEntries(new FormData(form));
|
let formData = Object.fromEntries(new FormData(form));
|
||||||
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||||
|
|
||||||
// Post as JSON
|
document.querySelector('progress').setAttribute('max', formData.steps);
|
||||||
|
|
||||||
|
// Post as JSON, using Fetch streaming to get results
|
||||||
fetch(form.action, {
|
fetch(form.action, {
|
||||||
method: form.method,
|
method: form.method,
|
||||||
body: JSON.stringify(formData),
|
body: JSON.stringify(formData),
|
||||||
}).then(async (result) => {
|
}).then(async (response) => {
|
||||||
let data = await result.json();
|
const reader = response.body.getReader();
|
||||||
|
|
||||||
|
let noOutputs = true;
|
||||||
|
while (true) {
|
||||||
|
let {value, done} = await reader.read();
|
||||||
|
value = new TextDecoder().decode(value);
|
||||||
|
if (done) break;
|
||||||
|
|
||||||
|
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||||
|
const data = JSON.parse(event);
|
||||||
|
|
||||||
|
if (data.event == 'result') {
|
||||||
|
noOutputs = false;
|
||||||
|
document.querySelector("#no-results-message")?.remove();
|
||||||
|
appendOutput(data.files[0],data.files[1],data.config)
|
||||||
|
} else if (data.event == 'upscaling-started') {
|
||||||
|
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "block";
|
||||||
|
} else if (data.event == 'upscaling-done') {
|
||||||
|
document.getElementById("scaling-inprocess-message").style.display = "none";
|
||||||
|
} else if (data.event == 'step') {
|
||||||
|
document.querySelector('progress').setAttribute('value', data.step.toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Re-enable form, remove no-results-message
|
// Re-enable form, remove no-results-message
|
||||||
form.querySelector('fieldset').removeAttribute('disabled');
|
form.querySelector('fieldset').removeAttribute('disabled');
|
||||||
document.querySelector("#prompt").value = prompt;
|
document.querySelector("#prompt").value = prompt;
|
||||||
|
document.querySelector('progress').setAttribute('value', '0');
|
||||||
|
|
||||||
if (data.outputs.length != 0) {
|
if (noOutputs) {
|
||||||
document.querySelector("#no-results-message")?.remove();
|
|
||||||
appendOutputs(data.outputs);
|
|
||||||
} else {
|
|
||||||
alert("Error occurred while generating.");
|
alert("Error occurred while generating.");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
Loading…
Reference in New Issue
Block a user