diff --git a/ldm/dream/image_util.py b/ldm/dream/image_util.py index fa14ec897b..e389fd50e3 100644 --- a/ldm/dream/image_util.py +++ b/ldm/dream/image_util.py @@ -1,3 +1,4 @@ +from math import sqrt, floor, ceil from PIL import Image class InitImageResizer(): @@ -51,4 +52,22 @@ class InitImageResizer(): return new_image - +def make_grid(image_list, rows=None, cols=None): + image_cnt = len(image_list) + if None in (rows, cols): + rows = floor(sqrt(image_cnt)) # try to make it square + cols = ceil(image_cnt / rows) + width = image_list[0].width + height = image_list[0].height + + grid_img = Image.new('RGB', (width * cols, height * rows)) + i = 0 + for r in range(0, rows): + for c in range(0, cols): + if i >= len(image_list): + break + grid_img.paste(image_list[i], (c * width, r * height)) + i = i + 1 + + return grid_img + diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py index f7838a58bf..f6b1762883 100644 --- a/ldm/dream/pngwriter.py +++ b/ldm/dream/pngwriter.py @@ -2,97 +2,42 @@ Two helper classes for dealing with PNG images and their path names. PngWriter -- Converts Images generated by T2I into PNGs, finds appropriate names for them, and writes prompt metadata - into the PNG. Intended to be subclassable in order to - create more complex naming schemes, including using the - prompt for file/directory names. + into the PNG. PromptFormatter -- Utility for converting a Namespace of prompt parameters back into a formatted prompt string with command-line switches. """ import os import re -from math import sqrt, floor, ceil -from PIL import Image, PngImagePlugin +from PIL import PngImagePlugin # -------------------image generation utils----- class PngWriter: - def __init__(self, outdir, prompt=None): + def __init__(self, outdir): self.outdir = outdir - self.prompt = prompt - self.filepath = None - self.files_written = [] os.makedirs(outdir, exist_ok=True) - def write_image(self, image, seed, upscaled=False): - self.filepath = self.unique_filename( - seed, upscaled, self.filepath - ) # will increment name in some sensible way - try: - prompt = f'{self.prompt} -S{seed}' - self.save_image_and_prompt_to_png(image, prompt, self.filepath) - except IOError as e: - print(e) - if not upscaled: - self.files_written.append([self.filepath, seed]) + # gives the next unique prefix in outdir + def unique_prefix(self): + # sort reverse alphabetically until we find max+1 + dirlist = sorted(os.listdir(self.outdir), reverse=True) + # find the first filename that matches our pattern or return 000000.0.png + existing_name = next( + (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), + '0000000.0.png', + ) + basecount = int(existing_name.split('.', 1)[0]) + 1 + return f'{basecount:06}' - def unique_filename(self, seed, upscaled=False, previouspath=None): - revision = 1 - - if previouspath is None: - # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(self.outdir), reverse=True) - # find the first filename that matches our pattern or return 000000.0.png - filename = next( - (f for f in dirlist if re.match('^(\d+)\..*\.png', f)), - '0000000.0.png', - ) - basecount = int(filename.split('.', 1)[0]) - basecount += 1 - filename = f'{basecount:06}.{seed}.png' - return os.path.join(self.outdir, filename) - - else: - basename = os.path.basename(previouspath) - x = re.match('^(\d+)\..*\.png', basename) - if not x: - return self.unique_filename(seed, upscaled, previouspath) - - basecount = int(x.groups()[0]) - series = 0 - finished = False - while not finished: - series += 1 - filename = f'{basecount:06}.{seed}.png' - path = os.path.join(self.outdir, filename) - if os.path.exists(path) and upscaled: - break - finished = not os.path.exists(path) - return os.path.join(self.outdir, filename) - - def save_image_and_prompt_to_png(self, image, prompt, path): + # saves image named _image_ to outdir/name, writing metadata from prompt + # returns full path of output + def save_image_and_prompt_to_png(self, image, prompt, name): + path = os.path.join(self.outdir, name) info = PngImagePlugin.PngInfo() info.add_text('Dream', prompt) image.save(path, 'PNG', pnginfo=info) - - def make_grid(self, image_list, rows=None, cols=None): - image_cnt = len(image_list) - if None in (rows, cols): - rows = floor(sqrt(image_cnt)) # try to make it square - cols = ceil(image_cnt / rows) - width = image_list[0].width - height = image_list[0].height - - grid_img = Image.new('RGB', (width * cols, height * rows)) - i = 0 - for r in range(0, rows): - for c in range(0, cols): - if i>=len(image_list): - break - grid_img.paste(image_list[i], (c * width, r * height)) - i = i + 1 - - return grid_img + return path class PromptFormatter: diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 346e114a2b..6a667f616b 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -88,24 +88,24 @@ class DreamServer(BaseHTTPRequestHandler): images_generated = 0 # helps keep track of when upscaling is started images_upscaled = 0 # helps keep track of when upscaling is completed - pngwriter = PngWriter( - "./outputs/img-samples/", config['prompt'], 1 - ) + pngwriter = PngWriter("./outputs/img-samples/") + prefix = pngwriter.unique_prefix() # if upscaling is requested, then this will be called twice, once when # the images are first generated, and then again when after upscaling # is complete. The upscaling replaces the original file, so the second # entry should not be inserted into the image list. def image_done(image, seed, upscaled=False): - pngwriter.write_image(image, seed, upscaled) + name = f'{prefix}.{seed}.png' + path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name) # Append post_data to log, but only once! if not upscaled: - current_image = pngwriter.files_written[-1] with open("./outputs/img-samples/dream_web_log.txt", "a") as log: - log.write(f"{current_image[0]}: {json.dumps(config)}\n") + log.write(f"{path}: {json.dumps(config)}\n") + self.wfile.write(bytes(json.dumps( - {'event':'result', 'files':current_image, 'config':config} + {'event': 'result', 'url': path, 'seed': seed, 'config': config} ) + '\n',"utf-8")) # control state of the "postprocessing..." message @@ -129,22 +129,24 @@ class DreamServer(BaseHTTPRequestHandler): {'event':action,'processed_file_cnt':f'{x}/{iterations}'} ) + '\n',"utf-8")) - # TODO: refactor PngWriter: - # it doesn't need to know if batch_size > 1, just if this is _part of a batch_ - step_writer = PngWriter('./outputs/intermediates/', prompt, 2) + step_writer = PngWriter('./outputs/intermediates/') + step_index = 1 def image_progress(sample, step): if self.canceled.is_set(): self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8')) raise CanceledException - url = None + path = None # since rendering images is moderately expensive, only render every 5th image # and don't bother with the last one, since it'll render anyway + nonlocal step_index if progress_images and step % 5 == 0 and step < steps - 1: image = self.model._sample_to_image(sample) - step_writer.write_image(image, seed) # TODO PngWriter to return path - url = step_writer.filepath + name = f'{prefix}.{seed}.{step_index}.png' + metadata = f'{prompt} -S{seed} [intermediate]' + path = step_writer.save_image_and_prompt_to_png(image, metadata, name) + step_index += 1 self.wfile.write(bytes(json.dumps( - {'event':'step', 'step':step + 1, 'url': url} + {'event': 'step', 'step': step + 1, 'url': path} ) + '\n',"utf-8")) try: diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index b83280471d..d969ac5e23 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -171,10 +171,14 @@ class T2I: Optional named arguments are the same as those passed to T2I and prompt2image() """ results = self.prompt2image(prompt, **kwargs) - pngwriter = PngWriter(outdir, prompt) - for r in results: - pngwriter.write_image(r[0], r[1]) - return pngwriter.files_written + pngwriter = PngWriter(outdir) + prefix = pngwriter.unique_prefix() + outputs = [] + for image, seed in results: + name = f'{prefix}.{seed}.png' + path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name) + outputs.append([path, seed]) + return outputs def txt2img(self, prompt, **kwargs): outdir = kwargs.pop('outdir', 'outputs/img-samples') @@ -349,10 +353,7 @@ class T2I: f'Error running RealESRGAN - Your image was not upscaled.\n{e}' ) if image_callback is not None: - if save_original: - image_callback(image, seed) - else: - image_callback(image, seed, upscaled=True) + image_callback(image, seed, upscaled=True) else: # no callback passed, so we simply replace old image with rescaled one result[0] = image diff --git a/scripts/dream.py b/scripts/dream.py index 0014fb6d4d..50be6dfa7c 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -12,6 +12,7 @@ import time import ldm.dream.readline from ldm.dream.pngwriter import PngWriter, PromptFormatter from ldm.dream.server import DreamServer, ThreadingDreamServer +from ldm.dream.image_util import make_grid def main(): """Initialize command-line parsers and the diffusion model""" @@ -203,24 +204,40 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile): # Here is where the images are actually generated! try: - file_writer = PngWriter(current_outdir, normalized_prompt) - callback = file_writer.write_image if individual_images else None - image_list = t2i.prompt2image(image_callback=callback, **vars(opt)) - results = ( - file_writer.files_written if individual_images else image_list - ) + file_writer = PngWriter(current_outdir) + prefix = file_writer.unique_prefix() + seeds = set() + results = [] + grid_images = dict() # seed -> Image, only used if `do_grid` + def image_writer(image, seed, upscaled=False): + if do_grid: + grid_images[seed] = image + else: + if upscaled and opt.save_original: + filename = f'{prefix}.{seed}.postprocessed.png' + else: + filename = f'{prefix}.{seed}.png' + path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename) + if (not upscaled) or opt.save_original: + # only append to results if we didn't overwrite an earlier output + results.append([path, seed]) - if do_grid and len(results) > 0: - grid_img = file_writer.make_grid([r[0] for r in results]) - filename = file_writer.unique_filename(results[0][1]) - seeds = [a[1] for a in results] - results = [[filename, seeds]] - metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' - file_writer.save_image_and_prompt_to_png( + seeds.add(seed) + + t2i.prompt2image(image_callback=image_writer, **vars(opt)) + + if do_grid and len(grid_images) > 0: + grid_img = make_grid(list(grid_images.values())) + first_seed = next(iter(seeds)) + filename = f'{prefix}.{first_seed}.png' + # TODO better metadata for grid images + metadata_prompt = f'{normalized_prompt} -S{first_seed}' + path = file_writer.save_image_and_prompt_to_png( grid_img, metadata_prompt, filename ) + results = [[path, seeds]] - last_seeds = [r[1] for r in results] + last_seeds = list(seeds) except AssertionError as e: print(e) diff --git a/static/dream_web/index.js b/static/dream_web/index.js index 5ef75a34a3..4b1c8ac2ec 100644 --- a/static/dream_web/index.js +++ b/static/dream_web/index.js @@ -95,7 +95,7 @@ async function generateSubmit(form) { if (data.event === 'result') { noOutputs = false; document.querySelector("#no-results-message")?.remove(); - appendOutput(data.files[0],data.files[1],data.config); + appendOutput(data.url, data.seed, data.config); progressEle.setAttribute('value', 0); progressEle.setAttribute('max', totalSteps); progressImageEle.src = BLANK_IMAGE_URL;