refactor pngwriter

This commit is contained in:
Kevin Gibbons 2022-08-30 21:21:04 -07:00
parent 3be1cee17c
commit 153c93bdd4
4 changed files with 74 additions and 84 deletions

View File

@ -17,62 +17,32 @@ from PIL import Image, PngImagePlugin
class PngWriter:
def __init__(self, outdir, prompt=None):
def __init__(self, outdir):
self.outdir = outdir
self.prompt = prompt
self.filepath = None
self.files_written = []
os.makedirs(outdir, exist_ok=True)
def write_image(self, image, seed, upscaled=False):
self.filepath = self.unique_filename(
seed, upscaled, self.filepath
) # will increment name in some sensible way
try:
prompt = f'{self.prompt} -S{seed}'
self.save_image_and_prompt_to_png(image, prompt, self.filepath)
except IOError as e:
print(e)
if not upscaled:
self.files_written.append([self.filepath, seed])
# gives the next unique prefix in outdir
def unique_prefix(self):
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
existing_name = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(existing_name.split('.', 1)[0]) + 1
return f'{basecount:06}'
def unique_filename(self, seed, upscaled=False, previouspath=None):
revision = 1
if previouspath is None:
# sort reverse alphabetically until we find max+1
dirlist = sorted(os.listdir(self.outdir), reverse=True)
# find the first filename that matches our pattern or return 000000.0.png
filename = next(
(f for f in dirlist if re.match('^(\d+)\..*\.png', f)),
'0000000.0.png',
)
basecount = int(filename.split('.', 1)[0])
basecount += 1
filename = f'{basecount:06}.{seed}.png'
return os.path.join(self.outdir, filename)
else:
basename = os.path.basename(previouspath)
x = re.match('^(\d+)\..*\.png', basename)
if not x:
return self.unique_filename(seed, upscaled, previouspath)
basecount = int(x.groups()[0])
series = 0
finished = False
while not finished:
series += 1
filename = f'{basecount:06}.{seed}.png'
path = os.path.join(self.outdir, filename)
finished = not os.path.exists(path)
return os.path.join(self.outdir, filename)
def save_image_and_prompt_to_png(self, image, prompt, path):
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, prompt, name):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', prompt)
image.save(path, 'PNG', pnginfo=info)
return path
# TODO move this to its own helper function; it's not really a method of pngwriter
def make_grid(self, image_list, rows=None, cols=None):
image_cnt = len(image_list)
if None in (rows, cols):

View File

@ -88,24 +88,25 @@ class DreamServer(BaseHTTPRequestHandler):
images_generated = 0 # helps keep track of when upscaling is started
images_upscaled = 0 # helps keep track of when upscaling is completed
pngwriter = PngWriter(
"./outputs/img-samples/", config['prompt'], 1
)
pngwriter = PngWriter("./outputs/img-samples/")
prefix = pngwriter.unique_prefix()
# if upscaling is requested, then this will be called twice, once when
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
def image_done(image, seed, upscaled=False):
pngwriter.write_image(image, seed, upscaled)
name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
# Append post_data to log, but only once!
if not upscaled:
current_image = pngwriter.files_written[-1]
with open("./outputs/img-samples/dream_web_log.txt", "a") as log:
log.write(f"{current_image[0]}: {json.dumps(config)}\n")
log.write(f"{path}: {json.dumps(config)}\n")
# TODO fix format of this event
self.wfile.write(bytes(json.dumps(
{'event':'result', 'files':current_image, 'config':config}
{'event': 'result', 'files': [path, seed], 'config': config}
) + '\n',"utf-8"))
# control state of the "postprocessing..." message
@ -129,22 +130,24 @@ class DreamServer(BaseHTTPRequestHandler):
{'event':action,'processed_file_cnt':f'{x}/{iterations}'}
) + '\n',"utf-8"))
# TODO: refactor PngWriter:
# it doesn't need to know if batch_size > 1, just if this is _part of a batch_
step_writer = PngWriter('./outputs/intermediates/', prompt, 2)
step_writer = PngWriter('./outputs/intermediates/')
step_index = 1
def image_progress(sample, step):
if self.canceled.is_set():
self.wfile.write(bytes(json.dumps({'event':'canceled'}) + '\n', 'utf-8'))
raise CanceledException
url = None
path = None
# since rendering images is moderately expensive, only render every 5th image
# and don't bother with the last one, since it'll render anyway
nonlocal step_index
if progress_images and step % 5 == 0 and step < steps - 1:
image = self.model._sample_to_image(sample)
step_writer.write_image(image, seed) # TODO PngWriter to return path
url = step_writer.filepath
name = f'{prefix}.{seed}.{step_index}.png'
metadata = f'{prompt} -S{seed} [intermediate]'
path = step_writer.save_image_and_prompt_to_png(image, metadata, name)
step_index += 1
self.wfile.write(bytes(json.dumps(
{'event':'step', 'step':step + 1, 'url': url}
{'event': 'step', 'step': step + 1, 'url': path}
) + '\n',"utf-8"))
try:

View File

@ -171,10 +171,14 @@ class T2I:
Optional named arguments are the same as those passed to T2I and prompt2image()
"""
results = self.prompt2image(prompt, **kwargs)
pngwriter = PngWriter(outdir, prompt)
for r in results:
pngwriter.write_image(r[0], r[1])
return pngwriter.files_written
pngwriter = PngWriter(outdir)
prefix = pngwriter.unique_prefix()
outputs = []
for image, seed in results:
name = f'{prefix}.{seed}.png'
path = pngwriter.save_image_and_prompt_to_png(image, f'{prompt} -S{seed}', name)
outputs.append([path, seed])
return outputs
def txt2img(self, prompt, **kwargs):
outdir = kwargs.pop('outdir', 'outputs/img-samples')
@ -349,10 +353,7 @@ class T2I:
f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
)
if image_callback is not None:
if save_original:
image_callback(image, seed)
else:
image_callback(image, seed, upscaled=True)
image_callback(image, seed, upscaled=True)
else: # no callback passed, so we simply replace old image with rescaled one
result[0] = image

View File

@ -203,24 +203,40 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
# Here is where the images are actually generated!
try:
file_writer = PngWriter(current_outdir, normalized_prompt)
callback = file_writer.write_image if individual_images else None
image_list = t2i.prompt2image(image_callback=callback, **vars(opt))
results = (
file_writer.files_written if individual_images else image_list
)
file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix()
seeds = set()
results = []
grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False):
if do_grid:
grid_images[seed] = image
else:
if upscaled and opt.save_original:
filename = f'{prefix}.{seed}.postprocessed.png'
else:
filename = f'{prefix}.{seed}.png'
path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename)
if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, seed])
if do_grid and len(results) > 0:
grid_img = file_writer.make_grid([r[0] for r in results])
filename = file_writer.unique_filename(results[0][1])
seeds = [a[1] for a in results]
results = [[filename, seeds]]
metadata_prompt = f'{normalized_prompt} -S{results[0][1]}'
file_writer.save_image_and_prompt_to_png(
seeds.add(seed)
t2i.prompt2image(image_callback=image_writer, **vars(opt))
if do_grid and len(grid_images) > 0:
grid_img = file_writer.make_grid(list(grid_images.values()))
first_seed = next(iter(seeds))
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
metadata_prompt = f'{normalized_prompt} -S{first_seed}'
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
)
results = [[path, seeds]]
last_seeds = [r[1] for r in results]
last_seeds = list(seeds)
except AssertionError as e:
print(e)