diff --git a/ldm/dream/server.py b/ldm/dream/server.py index e2eabc3e43..6d11afb8e0 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -3,6 +3,7 @@ import base64 import mimetypes import os from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from ldm.dream.pngwriter import PngWriter class DreamServer(BaseHTTPRequestHandler): model = None @@ -50,17 +51,42 @@ class DreamServer(BaseHTTPRequestHandler): print(f"Request to generate with prompt: {prompt}") - outputs = [] + def image_done(image, seed): + config = post_data.copy() # Shallow copy + config['initimg'] = '' + + # Write PNGs + pngwriter = PngWriter( + "./outputs/img-samples/", config['prompt'], 1 + ) + pngwriter.write_image(image, seed) + + # Append post_data to log + with open("./outputs/img-samples/dream_web_log.txt", "a") as log: + for file_path, _ in pngwriter.files_written: + log.write(f"{file_path}: {json.dumps(config)}\n") + + self.wfile.write(bytes(json.dumps( + {'event':'result', 'files':pngwriter.files_written, 'config':config} + ) + '\n',"utf-8")) + + def image_progress(image, step): + self.wfile.write(bytes(json.dumps( + {'event':'step', 'step':step} + ) + '\n',"utf-8")) + if initimg is None: # Run txt2img - outputs = self.model.txt2img(prompt, - iterations=iterations, - cfg_scale = cfgscale, - width = width, - height = height, - seed = seed, - steps = steps, - gfpgan_strength = gfpgan_strength) + self.model.prompt2image(prompt, + iterations=iterations, + cfg_scale = cfgscale, + width = width, + height = height, + seed = seed, + steps = steps, + + step_callback=image_progress, + image_callback=image_done) else: # Decode initimg as base64 to temp file with open("./img2img-tmp.png", "wb") as f: @@ -68,28 +94,19 @@ class DreamServer(BaseHTTPRequestHandler): f.write(base64.b64decode(initimg)) # Run img2img - outputs = self.model.img2img(prompt, - init_img = "./img2img-tmp.png", - iterations = iterations, - cfg_scale = cfgscale, - seed = seed, - gfpgan_strength=gfpgan_strength, 
- steps = steps) + self.model.prompt2image(prompt, + init_img = "./img2img-tmp.png", + iterations = iterations, + cfg_scale = cfgscale, + seed = seed, + steps = steps, + step_callback=image_progress, + image_callback=image_done) + # Remove the temp file os.remove("./img2img-tmp.png") - print(f"Prompt generated with output: {outputs}") - - post_data['initimg'] = '' # Don't send init image back - - # Append post_data to log - with open("./outputs/img-samples/dream_web_log.txt", "a") as log: - for output in outputs: - log.write(f"{output[0]}: {json.dumps(post_data)}\n") - - outputs = [x + [post_data] for x in outputs] # Append config to each output - result = {'outputs': outputs} - self.wfile.write(bytes(json.dumps(result), "utf-8")) + print(f"Prompt generated!") class ThreadingDreamServer(ThreadingHTTPServer): diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 1da81eee5a..6d463fa6ba 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -61,6 +61,9 @@ class KSampler(object): # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... **kwargs, ): + def route_callback(k_callback_values): + if img_callback is not None: + img_callback(k_callback_values['x'], k_callback_values['i']) sigmas = self.model.get_sigmas(S) if x_T: @@ -78,7 +81,8 @@ class KSampler(object): } return ( K.sampling.__dict__[f'sample_{self.schedule}']( - model_wrap_cfg, x, sigmas, extra_args=extra_args + model_wrap_cfg, x, sigmas, extra_args=extra_args, + callback=route_callback ), None, ) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index f1aec32c29..45656aa781 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -202,6 +202,7 @@ class T2I: ddim_eta=None, skip_normalize=False, image_callback=None, + step_callback=None, # these are specific to txt2img width=None, height=None, @@ -231,9 +232,14 @@ class T2I: gfpgan_strength // strength for GFPGAN. 
0.0 preserves image exactly, 1.0 replaces it completely ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants + step_callback // a function or method that will be called each step image_callback // a function or method that will be called each time an image is generated - To use the callback, define a function of method that receives two arguments, an Image object + To use the step callback, define a function that receives two arguments: + - Image GPU data + - The step number + + To use the image callback, define a function or method that receives two arguments, an Image object and the seed. You can then do whatever you like with the image, including converting it to different formats and manipulating it. For example: @@ -293,6 +299,7 @@ class T2I: skip_normalize=skip_normalize, init_img=init_img, strength=strength, + callback=step_callback, ) else: images_iterator = self._txt2img( @@ -305,6 +312,7 @@ class T2I: skip_normalize=skip_normalize, width=width, height=height, + callback=step_callback, ) with scope(self.device.type), self.model.ema_scope(): @@ -389,6 +397,7 @@ class T2I: skip_normalize, width, height, + callback, ): """ An infinite iterator of images from the prompt. @@ -412,6 +421,7 @@ class T2I: unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta, + img_callback=callback ) yield self._samples_to_images(samples) @@ -427,6 +437,7 @@ class T2I: skip_normalize, init_img, strength, + callback, # Currently not implemented for img2img ): """ An infinite iterator of images from the prompt and the initial image diff --git a/static/dream_web/index.html b/static/dream_web/index.html index 21591ab9b4..dfde207030 100644 --- a/static/dream_web/index.html +++ b/static/dream_web/index.html @@ -58,8 +58,9 @@
No results...
diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 3b99deecf4..3952201b73 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -7,12 +7,11 @@ function toBase64(file) { }); } -function appendOutput(output) { +function appendOutput(src, seed, config) { let outputNode = document.createElement("img"); - outputNode.src = output[0]; + outputNode.src = src; - let outputConfig = output[2]; - let altText = output[1].toString() + " | " + outputConfig.prompt; + let altText = seed.toString() + " | " + config.prompt; outputNode.alt = altText; outputNode.title = altText; @@ -20,9 +19,9 @@ function appendOutput(output) { outputNode.addEventListener('click', () => { let form = document.querySelector("#generate-form"); for (const [k, v] of new FormData(form)) { - form.querySelector(`*[name=${k}]`).value = outputConfig[k]; + form.querySelector(`*[name=${k}]`).value = config[k]; } - document.querySelector("#seed").value = output[1]; + document.querySelector("#seed").value = seed; saveFields(document.querySelector("#generate-form")); }); @@ -30,12 +29,6 @@ function appendOutput(output) { document.querySelector("#results").prepend(outputNode); } -function appendOutputs(outputs) { - for (const output of outputs) { - appendOutput(output); - } -} - function saveFields(form) { for (const [k, v] of new FormData(form)) { if (typeof v !== 'object') { // Don't save 'file' type @@ -59,21 +52,43 @@ async function generateSubmit(form) { let formData = Object.fromEntries(new FormData(form)); formData.initimg = formData.initimg.name !== '' ? 
await toBase64(formData.initimg) : null; - // Post as JSON + document.querySelector('progress').setAttribute('max', formData.steps); + + // Post as JSON, using Fetch streaming to get results fetch(form.action, { method: form.method, body: JSON.stringify(formData), - }).then(async (result) => { - let data = await result.json(); + }).then(async (response) => { + const reader = response.body.getReader(); + + let noOutputs = true; + while (true) { + let {value, done} = await reader.read(); + value = new TextDecoder().decode(value); + if (done) break; + + for (let event of value.split('\n').filter(e => e !== '')) { + const data = JSON.parse(event); + + if (data.event == 'result') { + noOutputs = false; + document.querySelector("#no-results-message")?.remove(); + + for (let [file, seed] of data.files) { + appendOutput(file, seed, data.config); + } + } else if (data.event == 'step') { + document.querySelector('progress').setAttribute('value', data.step.toString()); + } + } + } // Re-enable form, remove no-results-message form.querySelector('fieldset').removeAttribute('disabled'); document.querySelector("#prompt").value = prompt; + document.querySelector('progress').setAttribute('value', '0'); - if (data.outputs.length != 0) { - document.querySelector("#no-results-message")?.remove(); - appendOutputs(data.outputs); - } else { + if (noOutputs) { alert("Error occurred while generating."); } });