mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
correctly handle upscaling in webUI, including displaying status messages during GFPGAN/ESRGAN postprocessing
This commit is contained in:
commit
4acfb76be6
@ -68,15 +68,12 @@ class PngWriter:
|
||||
while not finished:
|
||||
series += 1
|
||||
filename = f'{basecount:06}.{seed}.png'
|
||||
if self.batch_size > 1 or os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
):
|
||||
path = os.path.join(self.outdir, filename)
|
||||
if self.batch_size > 1 or os.path.exists(path):
|
||||
if upscaled:
|
||||
break
|
||||
filename = f'{basecount:06}.{seed}.{series:02}.png'
|
||||
finished = not os.path.exists(
|
||||
os.path.join(self.outdir, filename)
|
||||
)
|
||||
finished = not os.path.exists(path)
|
||||
return os.path.join(self.outdir, filename)
|
||||
|
||||
def save_image_and_prompt_to_png(self, image, prompt, path):
|
||||
|
@ -3,6 +3,7 @@ import base64
|
||||
import mimetypes
|
||||
import os
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from ldm.dream.pngwriter import PngWriter
|
||||
|
||||
class DreamServer(BaseHTTPRequestHandler):
|
||||
model = None
|
||||
@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
seed = None if int(post_data['seed']) == -1 else int(post_data['seed'])
|
||||
|
||||
print(f"Request to generate with prompt: {prompt}")
|
||||
# In order to handle upscaled images, the PngWriter needs to maintain state
|
||||
# across images generated by each call to prompt2img(), so we define it in
|
||||
# the outer scope of image_done()
|
||||
config = post_data.copy() # Shallow copy
|
||||
config['initimg'] = ''
|
||||
|
||||
images_generated = 0 # helps keep track of when upscaling is started
|
||||
images_upscaled = 0 # helps keep track of when upscaling is completed
|
||||
pngwriter = PngWriter(
|
||||
"./outputs/img-samples/", config['prompt'], 1
|
||||
)
|
||||
|
||||
# if upscaling is requested, then this will be called twice, once when
|
||||
# the images are first generated, and then again when after upscaling
|
||||
# is complete. The upscaling replaces the original file, so the second
|
||||
# entry should not be inserted into the image list.
|
||||
def image_done(image, seed, upscaled=False):
|
||||
pngwriter.write_image(image, seed, upscaled)
|
||||
|
||||
# Append post_data to log, but only once!
|
||||
if not upscaled:
|
||||
current_image = pngwriter.files_written[-1]
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
|
||||
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'result', 'files':current_image, 'config':config}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
# control state of the "postprocessing..." message
|
||||
upscaling_requested = upscale or gfpgan_strength>0
|
||||
nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure.
|
||||
if upscaled:
|
||||
images_upscaled += 1
|
||||
else:
|
||||
images_generated +=1
|
||||
if upscaling_requested:
|
||||
action = None
|
||||
if images_generated >= iterations:
|
||||
if images_upscaled < iterations:
|
||||
action = 'upscaling-started'
|
||||
else:
|
||||
action = 'upscaling-done'
|
||||
if action:
|
||||
x = images_upscaled+1
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
def image_progress(image, step):
|
||||
self.wfile.write(bytes(json.dumps(
|
||||
{'event':'step', 'step':step}
|
||||
) + '\n',"utf-8"))
|
||||
|
||||
outputs = []
|
||||
if initimg is None:
|
||||
# Run txt2img
|
||||
outputs = self.model.txt2img(prompt,
|
||||
self.model.prompt2image(prompt,
|
||||
iterations=iterations,
|
||||
cfg_scale = cfgscale,
|
||||
width = width,
|
||||
@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
gfpgan_strength = gfpgan_strength,
|
||||
upscale = upscale
|
||||
)
|
||||
upscale = upscale,
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
else:
|
||||
# Decode initimg as base64 to temp file
|
||||
with open("./img2img-tmp.png", "wb") as f:
|
||||
@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler):
|
||||
f.write(base64.b64decode(initimg))
|
||||
|
||||
# Run img2img
|
||||
outputs = self.model.img2img(prompt,
|
||||
self.model.prompt2image(prompt,
|
||||
init_img = "./img2img-tmp.png",
|
||||
iterations = iterations,
|
||||
cfg_scale = cfgscale,
|
||||
seed = seed,
|
||||
steps = steps,
|
||||
gfpgan_strength=gfpgan_strength,
|
||||
upscale = upscale,
|
||||
steps = steps
|
||||
)
|
||||
step_callback=image_progress,
|
||||
image_callback=image_done)
|
||||
|
||||
# Remove the temp file
|
||||
os.remove("./img2img-tmp.png")
|
||||
|
||||
print(f"Prompt generated with output: {outputs}")
|
||||
|
||||
post_data['initimg'] = '' # Don't send init image back
|
||||
|
||||
# Append post_data to log
|
||||
with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log:
|
||||
for output in outputs:
|
||||
log.write(f"{output[0]}: {json.dumps(post_data)}\n")
|
||||
|
||||
outputs = [x + [post_data] for x in outputs] # Append config to each output
|
||||
result = {'outputs': outputs}
|
||||
self.wfile.write(bytes(json.dumps(result), "utf-8"))
|
||||
print(f"Prompt generated!")
|
||||
|
||||
|
||||
class ThreadingDreamServer(ThreadingHTTPServer):
|
||||
|
@ -61,6 +61,9 @@ class KSampler(object):
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
**kwargs,
|
||||
):
|
||||
def route_callback(k_callback_values):
|
||||
if img_callback is not None:
|
||||
img_callback(k_callback_values['x'], k_callback_values['i'])
|
||||
|
||||
sigmas = self.model.get_sigmas(S)
|
||||
if x_T:
|
||||
@ -78,7 +81,8 @@ class KSampler(object):
|
||||
}
|
||||
return (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
@ -201,6 +201,7 @@ class T2I:
|
||||
ddim_eta=None,
|
||||
skip_normalize=False,
|
||||
image_callback=None,
|
||||
step_callback=None,
|
||||
# these are specific to txt2img
|
||||
width=None,
|
||||
height=None,
|
||||
@ -230,9 +231,14 @@ class T2I:
|
||||
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
|
||||
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
|
||||
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
|
||||
step_callback // a function or method that will be called each step
|
||||
image_callback // a function or method that will be called each time an image is generated
|
||||
|
||||
To use the callback, define a function of method that receives two arguments, an Image object
|
||||
To use the step callback, define a function that receives two arguments:
|
||||
- Image GPU data
|
||||
- The step number
|
||||
|
||||
To use the image callback, define a function of method that receives two arguments, an Image object
|
||||
and the seed. You can then do whatever you like with the image, including converting it to
|
||||
different formats and manipulating it. For example:
|
||||
|
||||
@ -292,6 +298,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
init_img=init_img,
|
||||
strength=strength,
|
||||
callback=step_callback,
|
||||
)
|
||||
else:
|
||||
images_iterator = self._txt2img(
|
||||
@ -304,6 +311,7 @@ class T2I:
|
||||
skip_normalize=skip_normalize,
|
||||
width=width,
|
||||
height=height,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
with scope(self.device.type), self.model.ema_scope():
|
||||
@ -390,6 +398,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
width,
|
||||
height,
|
||||
callback,
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt.
|
||||
@ -413,6 +422,7 @@ class T2I:
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=ddim_eta,
|
||||
img_callback=callback
|
||||
)
|
||||
yield self._samples_to_images(samples)
|
||||
|
||||
@ -428,6 +438,7 @@ class T2I:
|
||||
skip_normalize,
|
||||
init_img,
|
||||
strength,
|
||||
callback, # Currently not implemented for img2img
|
||||
):
|
||||
"""
|
||||
An infinite iterator of images from the prompt and the initial image
|
||||
|
@ -18,6 +18,11 @@ fieldset {
|
||||
#fieldset-search {
|
||||
display: flex;
|
||||
}
|
||||
#scaling-inprocess-message{
|
||||
font-weight: bold;
|
||||
font-style: italic;
|
||||
display: none;
|
||||
}
|
||||
#prompt {
|
||||
flex-grow: 1;
|
||||
|
||||
|
@ -79,8 +79,12 @@
|
||||
</fieldset>
|
||||
</form>
|
||||
<div id="about">For news and support for this web service, visit our <a href="http://github.com/lstein/stable-diffusion">GitHub site</a></div>
|
||||
<br>
|
||||
<progress id="progress" value="0" max="1"></progress>
|
||||
<div id="scaling-inprocess-message">
|
||||
<i><span>Postprocessing...</span><span id="processing_cnt">1/3</span></i>
|
||||
</div>
|
||||
</div>
|
||||
<hr>
|
||||
<div id="results">
|
||||
<div id="no-results-message">
|
||||
<i><p>No results...</p></i>
|
||||
|
@ -7,12 +7,11 @@ function toBase64(file) {
|
||||
});
|
||||
}
|
||||
|
||||
function appendOutput(output) {
|
||||
function appendOutput(src, seed, config) {
|
||||
let outputNode = document.createElement("img");
|
||||
outputNode.src = output[0];
|
||||
outputNode.src = src;
|
||||
|
||||
let outputConfig = output[2];
|
||||
let altText = output[1].toString() + " | " + outputConfig.prompt;
|
||||
let altText = seed.toString() + " | " + config.prompt;
|
||||
outputNode.alt = altText;
|
||||
outputNode.title = altText;
|
||||
|
||||
@ -20,9 +19,9 @@ function appendOutput(output) {
|
||||
outputNode.addEventListener('click', () => {
|
||||
let form = document.querySelector("#generate-form");
|
||||
for (const [k, v] of new FormData(form)) {
|
||||
form.querySelector(`*[name=${k}]`).value = outputConfig[k];
|
||||
form.querySelector(`*[name=${k}]`).value = config[k];
|
||||
}
|
||||
document.querySelector("#seed").value = output[1];
|
||||
document.querySelector("#seed").value = seed;
|
||||
|
||||
saveFields(document.querySelector("#generate-form"));
|
||||
});
|
||||
@ -30,12 +29,6 @@ function appendOutput(output) {
|
||||
document.querySelector("#results").prepend(outputNode);
|
||||
}
|
||||
|
||||
function appendOutputs(outputs) {
|
||||
for (const output of outputs) {
|
||||
appendOutput(output);
|
||||
}
|
||||
}
|
||||
|
||||
function saveFields(form) {
|
||||
for (const [k, v] of new FormData(form)) {
|
||||
if (typeof v !== 'object') { // Don't save 'file' type
|
||||
@ -65,21 +58,45 @@ async function generateSubmit(form) {
|
||||
let formData = Object.fromEntries(new FormData(form));
|
||||
formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null;
|
||||
|
||||
// Post as JSON
|
||||
document.querySelector('progress').setAttribute('max', formData.steps);
|
||||
|
||||
// Post as JSON, using Fetch streaming to get results
|
||||
fetch(form.action, {
|
||||
method: form.method,
|
||||
body: JSON.stringify(formData),
|
||||
}).then(async (result) => {
|
||||
let data = await result.json();
|
||||
}).then(async (response) => {
|
||||
const reader = response.body.getReader();
|
||||
|
||||
let noOutputs = true;
|
||||
while (true) {
|
||||
let {value, done} = await reader.read();
|
||||
value = new TextDecoder().decode(value);
|
||||
if (done) break;
|
||||
|
||||
for (let event of value.split('\n').filter(e => e !== '')) {
|
||||
const data = JSON.parse(event);
|
||||
|
||||
if (data.event == 'result') {
|
||||
noOutputs = false;
|
||||
document.querySelector("#no-results-message")?.remove();
|
||||
appendOutput(data.files[0],data.files[1],data.config)
|
||||
} else if (data.event == 'upscaling-started') {
|
||||
document.getElementById("processing_cnt").textContent=data.processed_file_cnt;
|
||||
document.getElementById("scaling-inprocess-message").style.display = "block";
|
||||
} else if (data.event == 'upscaling-done') {
|
||||
document.getElementById("scaling-inprocess-message").style.display = "none";
|
||||
} else if (data.event == 'step') {
|
||||
document.querySelector('progress').setAttribute('value', data.step.toString());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable form, remove no-results-message
|
||||
form.querySelector('fieldset').removeAttribute('disabled');
|
||||
document.querySelector("#prompt").value = prompt;
|
||||
document.querySelector('progress').setAttribute('value', '0');
|
||||
|
||||
if (data.outputs.length != 0) {
|
||||
document.querySelector("#no-results-message")?.remove();
|
||||
appendOutputs(data.outputs);
|
||||
} else {
|
||||
if (noOutputs) {
|
||||
alert("Error occurred while generating.");
|
||||
}
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user