InvokeAI/invokeai/backend/image_util/pngwriter.py

118 lines
4.2 KiB
Python
Raw Normal View History

"""
Two helper classes for dealing with PNG images and their path names.
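PngWriter -- Converts images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
into the PNG.
PromptFormatter -- Formats a prompt and its generation switches
as a single command-line string.
Exports functions retrieve_metadata(path) and write_metadata(path, meta)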
"""
2023-03-03 06:02:00 +00:00
import json
import os
import re
2023-03-03 06:02:00 +00:00
from PIL import Image, PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
    """Write generated images as PNG files into ``outdir``, embedding prompt metadata."""

    # File names of the form "<digits>.<anything>.png" (e.g. "000042.123.png");
    # group 1 captures the leading numeric counter.
    _PREFIX_RE = re.compile(r"^(\d+)\..*\.png")

    def __init__(self, outdir):
        """Remember ``outdir`` and create it (including parents) if missing."""
        self.outdir = outdir
        os.makedirs(outdir, exist_ok=True)

    def unique_prefix(self):
        """
        Return the next unused numeric file-name prefix in ``outdir``.

        The result is one more than the largest counter found on files named
        like ``<digits>.<anything>.png``, zero-padded to at least six digits.
        A directory with no such files yields "000001".
        """
        counters = (
            int(match.group(1))
            for match in map(self._PREFIX_RE.match, os.listdir(self.outdir))
            if match
        )
        # Compare numerically: a lexical sort ranks "999999" above "1000000"
        # (or "9" above "000010") and would hand out a prefix already in use.
        return f"{max(counters, default=0) + 1:06}"

    def save_image_and_prompt_to_png(self, image, dream_prompt, name, metadata=None, compress_level=6):
        """
        Save ``image`` as a PNG at ``outdir/name`` with prompt metadata.

        :param image: PIL image to save.
        :param dream_prompt: text stored in the "Dream" PNG text chunk.
        :param name: file name (relative to ``outdir``) to write.
        :param metadata: optional JSON-serializable object stored, JSON-encoded,
            in the "sd-metadata" text chunk; skipped when falsy.
        :param compress_level: PNG compression level passed through to PIL.
        :returns: full path of the written file.
        """
        path = os.path.join(self.outdir, name)
        info = PngImagePlugin.PngInfo()
        info.add_text("Dream", dream_prompt)
        if metadata:
            info.add_text("sd-metadata", json.dumps(metadata))
        image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
        return path

    def retrieve_metadata(self, img_basename):
        """
        Given a PNG filename stored in outdir, return the decoded "sd-metadata"
        stored there, as a dict (``{}`` if the image has none).
        """
        path = os.path.join(self.outdir, img_basename)
        all_metadata = retrieve_metadata(path)
        return all_metadata["sd-metadata"]
def retrieve_metadata(img_path):
    """
    Read the prompt metadata text chunks stored in an image.

    :param img_path: path to an image file readable by PIL.
    :returns: dict with two keys:
        "sd-metadata" -- the JSON-decoded "sd-metadata" text chunk
        (``{}`` if absent), and
        "Dream" -- the "Dream" prompt string (``""`` if absent).
    """
    # The context manager releases the file handle; the original leaked it.
    with Image.open(img_path) as im:
        # Images without a ``text`` payload (e.g. JPG) fall back to empty
        # metadata. ``text`` is read inside the ``with`` so the file is still open.
        text_chunks = im.text if hasattr(im, "text") else {}
        md = text_chunks.get("sd-metadata", "{}")
        dream_prompt = text_chunks.get("Dream", "")
    return {"sd-metadata": json.loads(md), "Dream": dream_prompt}
2023-03-03 06:02:00 +00:00
def write_metadata(img_path: str, meta: dict):
    """
    Rewrite the image at ``img_path`` as a PNG whose "sd-metadata" text chunk
    holds ``meta`` (JSON-encoded).

    Only the "sd-metadata" chunk is passed to the save call, so other text
    chunks present in the original file (e.g. "Dream") are not written back.

    :param img_path: path of the image to overwrite in place.
    :param meta: JSON-serializable dict to store.
    """
    with Image.open(img_path) as src:
        # copy() forces a full decode into a detached in-memory image, so the
        # source file handle is released before we write over the same path.
        im = src.copy()
    info = PngImagePlugin.PngInfo()
    info.add_text("sd-metadata", json.dumps(meta))
    im.save(img_path, "PNG", pnginfo=info)
class PromptFormatter:
    """Render a prompt and its generation settings as a command-line string."""

    def __init__(self, t2i, opt):
        # t2i supplies fallback values; opt holds this request's options.
        # (presumably an argparse-style namespace — confirm against callers)
        self.t2i = t2i
        self.opt = opt

    # note: the t2i object should provide all these values.
    # there should be no need to OR against opt values
    def normalize_prompt(self):
        """
        Normalize the prompt and switches into one space-separated string.

        The double-quoted prompt comes first. Next come the core switches
        (-s steps, -W width, -H height, -C cfg_scale, -A sampler_name); each takes
        the ``opt`` value when truthy and otherwise the ``t2i`` value. The
        optional switches --seamless, -I, --fit, -f, -G, -U, -v and -V are
        appended only when the corresponding option is set.
        """
        t2i = self.t2i
        opt = self.opt
        switches = [f'"{opt.prompt}"']
        switches.append(f"-s{opt.steps or t2i.steps}")
        switches.append(f"-W{opt.width or t2i.width}")
        switches.append(f"-H{opt.height or t2i.height}")
        switches.append(f"-C{opt.cfg_scale or t2i.cfg_scale}")
        switches.append(f"-A{opt.sampler_name or t2i.sampler_name}")
        # TODO: put model name into the t2i object and emit --model here.
        if opt.seamless or t2i.seamless:
            switches.append("--seamless")
        if opt.init_img:
            switches.append(f"-I{opt.init_img}")
        if opt.fit:
            switches.append("--fit")
        # -f is only emitted alongside an init image. opt.strength is truthy
        # here, so the former "or t2i.strength" fallback could never apply.
        if opt.strength and opt.init_img is not None:
            switches.append(f"-f{opt.strength}")
        if opt.gfpgan_strength:
            switches.append(f"-G{opt.gfpgan_strength}")
        if opt.upscale:
            switches.append(f'-U {" ".join(str(u) for u in opt.upscale)}')
        # Guard against an unset (None) variation amount, which would
        # otherwise raise TypeError on the comparison.
        if opt.variation_amount is not None and opt.variation_amount > 0:
            switches.append(f"-v{opt.variation_amount}")
        if opt.with_variations:
            formatted_variations = ",".join(f"{seed}:{weight}" for seed, weight in opt.with_variations)
            switches.append(f"-V{formatted_variations}")
        return " ".join(switches)