diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index feae9f387d..42bb575a97 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -68,15 +68,12 @@ class PngWriter: while not finished: series += 1 filename = f'{basecount:06}.{seed}.png' - if self.batch_size > 1 or os.path.exists( - os.path.join(self.outdir, filename) - ): + path = os.path.join(self.outdir, filename) + if self.batch_size > 1 or os.path.exists(path): if upscaled: break filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists( - os.path.join(self.outdir, filename) - ) + finished = not os.path.exists(path) return os.path.join(self.outdir, filename) def save_image_and_prompt_to_png(self, image, prompt, path): diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 79fce320ce..efe3139f74 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -3,6 +3,7 @@ import base64 import mimetypes import os from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from ldm.dream.pngwriter import PngWriter class DreamServer(BaseHTTPRequestHandler): model = None @@ -52,11 +53,63 @@ class DreamServer(BaseHTTPRequestHandler): seed = None if int(post_data['seed']) == -1 else int(post_data['seed']) print(f"Request to generate with prompt: {prompt}") + # In order to handle upscaled images, the PngWriter needs to maintain state + # across images generated by each call to prompt2img(), so we define it in + # the outer scope of image_done() + config = post_data.copy() # Shallow copy + config['initimg'] = '' + + images_generated = 0 # helps keep track of when upscaling is started + images_upscaled = 0 # helps keep track of when upscaling is completed + pngwriter = PngWriter( + "./outputs/img-samples/", config['prompt'], 1 + ) + + # if upscaling is requested, then this will be called twice, once when + # the images are first generated, and then again when after upscaling + # is complete. The upscaling replaces the original file, so the second + # entry should not be inserted into the image list. + def image_done(image, seed, upscaled=False): + pngwriter.write_image(image, seed, upscaled) + + # Append post_data to log, but only once! + if not upscaled: + current_image = pngwriter.files_written[-1] + with open("./outputs/img-samples/dream_web_log.txt", "a") as log: + log.write(f"{current_image[0]}: {json.dumps(config)}\n") + self.wfile.write(bytes(json.dumps( + {'event':'result', 'files':current_image, 'config':config} + ) + '\n',"utf-8")) + + # control state of the "postprocessing..." message + upscaling_requested = upscale or gfpgan_strength>0 + nonlocal images_generated # NB: Is this bad python style? It is typical usage in a perl closure. + nonlocal images_upscaled # NB: Is this bad python style? It is typical usage in a perl closure. + if upscaled: + images_upscaled += 1 + else: + images_generated +=1 + if upscaling_requested: + action = None + if images_generated >= iterations: + if images_upscaled < iterations: + action = 'upscaling-started' + else: + action = 'upscaling-done' + if action: + x = images_upscaled+1 + self.wfile.write(bytes(json.dumps( + {'event':action,'processed_file_cnt':f'{x}/{iterations}'} + ) + '\n',"utf-8")) + + def image_progress(image, step): + self.wfile.write(bytes(json.dumps( + {'event':'step', 'step':step} + ) + '\n',"utf-8")) - outputs = [] if initimg is None: # Run txt2img - outputs = self.model.txt2img(prompt, + self.model.prompt2image(prompt, iterations=iterations, cfg_scale = cfgscale, width = width, @@ -64,8 +117,9 @@ class DreamServer(BaseHTTPRequestHandler): seed = seed, steps = steps, gfpgan_strength = gfpgan_strength, - upscale = upscale - ) + upscale = upscale, + step_callback=image_progress, + image_callback=image_done) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: @@ -73,30 +127,21 @@ class DreamServer(BaseHTTPRequestHandler): f.write(base64.b64decode(initimg)) # Run img2img - outputs = self.model.img2img(prompt, - init_img = "./img2img-tmp.png", - iterations = iterations, - cfg_scale = cfgscale, - seed = seed, - gfpgan_strength=gfpgan_strength, - upscale = upscale, - steps = steps - ) + self.model.prompt2image(prompt, + init_img = "./img2img-tmp.png", + iterations = iterations, + cfg_scale = cfgscale, + seed = seed, + steps = steps, + gfpgan_strength=gfpgan_strength, + upscale = upscale, + step_callback=image_progress, + image_callback=image_done) + # Remove the temp file os.remove("./img2img-tmp.png") - print(f"Prompt generated with output: {outputs}") - - post_data['initimg'] = '' # Don't send init image back - - # Append post_data to log - with open("./outputs/img-samples/dream_web_log.txt", "a", encoding="utf-8") as log: - for output in outputs: - log.write(f"{output[0]}: {json.dumps(post_data)}\n") - - outputs = [x + [post_data] for x in outputs] # Append config to each output - result = {'outputs': outputs} - self.wfile.write(bytes(json.dumps(result), "utf-8")) + print(f"Prompt generated!") class ThreadingDreamServer(ThreadingHTTPServer): diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 1da81eee5a..6d463fa6ba 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -61,6 +61,9 @@ class KSampler(object): # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs, ): + def route_callback(k_callback_values): + if img_callback is not None: + img_callback(k_callback_values['x'], k_callback_values['i']) sigmas = self.model.get_sigmas(S) if x_T: @@ -78,7 +81,8 @@ class KSampler(object): } return ( K.sampling.__dict__[f'sample_{self.schedule}']( - model_wrap_cfg, x, sigmas, extra_args=extra_args + model_wrap_cfg, x, sigmas, extra_args=extra_args, + callback=route_callback ), None, ) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index d2f10c4a81..0f65acccb7 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -201,6 +201,7 @@ class T2I: ddim_eta=None, skip_normalize=False, image_callback=None, + step_callback=None, # these are specific to txt2img width=None, height=None, @@ -230,9 +231,14 @@ class T2I: gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants + step_callback // a function or method that will be called each step image_callback // a function or method that will be called each time an image is generated - To use the callback, define a function of method that receives two arguments, an Image object + To use the step callback, define a function that receives two arguments: + - Image GPU data + - The step number + + To use the image callback, define a function of method that receives two arguments, an Image object and the seed. You can then do whatever you like with the image, including converting it to different formats and manipulating it. For example: @@ -292,6 +298,7 @@ class T2I: skip_normalize=skip_normalize, init_img=init_img, strength=strength, + callback=step_callback, ) else: images_iterator = self._txt2img( @@ -304,6 +311,7 @@ class T2I: skip_normalize=skip_normalize, width=width, height=height, + callback=step_callback, ) with scope(self.device.type), self.model.ema_scope(): @@ -390,6 +398,7 @@ class T2I: skip_normalize, width, height, + callback, ): """ An infinite iterator of images from the prompt. @@ -413,6 +422,7 @@ class T2I: unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, + img_callback=callback ) yield self._samples_to_images(samples) @@ -428,6 +438,7 @@ class T2I: skip_normalize, init_img, strength, + callback, # Currently not implemented for img2img ): """ An infinite iterator of images from the prompt and the initial image diff --git a/static/dream_web/index.css b/static/dream_web/index.css index d4240d3f35..2137b9907d 100644 --- a/static/dream_web/index.css +++ b/static/dream_web/index.css @@ -18,6 +18,11 @@ fieldset { #fieldset-search { display: flex; } +#scaling-inprocess-message{ + font-weight: bold; + font-style: italic; + display: none; +} #prompt { flex-grow: 1; diff --git a/static/dream_web/index.html b/static/dream_web/index.html index 31816a2cd4..27dbbbfe84 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -79,8 +79,12 @@
For news and support for this web service, visit our GitHub site
+
+ +
+ Postprocessing...1/3 +
-

No results...

diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 74ab1ef5f5..839451c50b 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -7,12 +7,11 @@ function toBase64(file) { }); } -function appendOutput(output) { +function appendOutput(src, seed, config) { let outputNode = document.createElement("img"); - outputNode.src = output[0]; + outputNode.src = src; - let outputConfig = output[2]; - let altText = output[1].toString() + " | " + outputConfig.prompt; + let altText = seed.toString() + " | " + config.prompt; outputNode.alt = altText; outputNode.title = altText; @@ -20,9 +19,9 @@ function appendOutput(output) { outputNode.addEventListener('click', () => { let form = document.querySelector("#generate-form"); for (const [k, v] of new FormData(form)) { - form.querySelector(`*[name=${k}]`).value = outputConfig[k]; + form.querySelector(`*[name=${k}]`).value = config[k]; } - document.querySelector("#seed").value = output[1]; + document.querySelector("#seed").value = seed; saveFields(document.querySelector("#generate-form")); }); @@ -30,12 +29,6 @@ function appendOutput(output) { document.querySelector("#results").prepend(outputNode); } -function appendOutputs(outputs) { - for (const output of outputs) { - appendOutput(output); - } -} - function saveFields(form) { for (const [k, v] of new FormData(form)) { if (typeof v !== 'object') { // Don't save 'file' type @@ -65,21 +58,45 @@ async function generateSubmit(form) { let formData = Object.fromEntries(new FormData(form)); formData.initimg = formData.initimg.name !== '' ? await toBase64(formData.initimg) : null; - // Post as JSON + document.querySelector('progress').setAttribute('max', formData.steps); + + // Post as JSON, using Fetch streaming to get results fetch(form.action, { method: form.method, body: JSON.stringify(formData), - }).then(async (result) => { - let data = await result.json(); + }).then(async (response) => { + const reader = response.body.getReader(); + + let noOutputs = true; + while (true) { + let {value, done} = await reader.read(); + value = new TextDecoder().decode(value); + if (done) break; + + for (let event of value.split('\n').filter(e => e !== '')) { + const data = JSON.parse(event); + + if (data.event == 'result') { + noOutputs = false; + document.querySelector("#no-results-message")?.remove(); + appendOutput(data.files[0],data.files[1],data.config) + } else if (data.event == 'upscaling-started') { + document.getElementById("processing_cnt").textContent=data.processed_file_cnt; + document.getElementById("scaling-inprocess-message").style.display = "block"; + } else if (data.event == 'upscaling-done') { + document.getElementById("scaling-inprocess-message").style.display = "none"; + } else if (data.event == 'step') { + document.querySelector('progress').setAttribute('value', data.step.toString()); + } + } + } // Re-enable form, remove no-results-message form.querySelector('fieldset').removeAttribute('disabled'); document.querySelector("#prompt").value = prompt; + document.querySelector('progress').setAttribute('value', '0'); - if (data.outputs.length != 0) { - document.querySelector("#no-results-message")?.remove(); - appendOutputs(data.outputs); - } else { + if (noOutputs) { alert("Error occurred while generating."); } });