InvokeAI/invokeai/backend/image_util/pngwriter.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

122 lines
4.3 KiB
Python
Raw Normal View History

"""
Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
             appropriate names for them, and writes prompt metadata
             into the PNG.
PromptFormatter -- Builds a normalized prompt string with its switches.
Exports functions retrieve_metadata(path) and write_metadata(path, meta)
"""
2023-03-03 06:02:00 +00:00
import json
import os
import re
2023-03-03 06:02:00 +00:00
from PIL import Image, PngImagePlugin
# -------------------image generation utils-----
class PngWriter:
    """Write images as PNG files into one output directory.

    The directory is created on construction if it does not exist yet.
    Files are expected to be named ``<number>.<anything>.png``; the numeric
    part is what :meth:`unique_prefix` counts.
    """

    # Same pattern the file names have always been matched against; compiled
    # once and written as a raw string so ``\d`` is not an invalid escape.
    _NUMBERED_PNG = re.compile(r"^(\d+)\..*\.png")

    def __init__(self, outdir):
        """Remember *outdir* and make sure it exists."""
        self.outdir = outdir
        os.makedirs(outdir, exist_ok=True)

    def unique_prefix(self):
        """Return the next unused numeric prefix in ``outdir``.

        The result is the highest numeric prefix found among matching file
        names plus one, zero-padded to at least six digits. An output
        directory with no matching files yields ``"000001"``.

        The maximum is taken numerically. Picking the first entry of a
        reverse alphabetical listing is only correct while every prefix has
        the same width: ``"999999"`` sorts above ``"1000000"`` as text, which
        would hand out a prefix that is already in use.
        """
        highest = 0
        for filename in os.listdir(self.outdir):
            found = self._NUMBERED_PNG.match(filename)
            if found:
                highest = max(highest, int(found.group(1)))
        return f"{highest + 1:06}"

    def save_image_and_prompt_to_png(
        self, image, dream_prompt, name, metadata=None, compress_level=6
    ):
        """Save *image* to ``outdir/name`` as a PNG with embedded text chunks.

        *dream_prompt* is stored under the ``"Dream"`` key. When *metadata*
        is truthy it is serialized with ``json.dumps`` and stored under the
        ``"sd-metadata"`` key. *compress_level* is passed straight through to
        ``image.save``.

        Returns the full path of the written file.
        """
        path = os.path.join(self.outdir, name)
        info = PngImagePlugin.PngInfo()
        info.add_text("Dream", dream_prompt)
        if metadata:
            info.add_text("sd-metadata", json.dumps(metadata))
        image.save(path, "PNG", pnginfo=info, compress_level=compress_level)
        return path

    def retrieve_metadata(self, img_basename):
        """Return the ``"sd-metadata"`` dict of a PNG stored in ``outdir``.

        *img_basename* is joined onto ``outdir``; the actual reading is done
        by the module-level ``retrieve_metadata`` function.
        """
        path = os.path.join(self.outdir, img_basename)
        all_metadata = retrieve_metadata(path)
        return all_metadata["sd-metadata"]
def retrieve_metadata(img_path):
    """
    Given a path to an image, returns a dict with two entries:
    "sd-metadata" -- the JSON-decoded "sd-metadata" text chunk ({} if absent)
    "Dream"       -- the "Dream" text chunk ("" if absent)

    Images that carry no 'text' payload at all (such as JPG images) give
    the same empty defaults.
    """
    # The context manager closes the underlying file once the text chunks
    # have been read, instead of leaving the handle open.
    with Image.open(img_path) as im:
        if hasattr(im, "text"):
            md = im.text.get("sd-metadata", "{}")
            dream_prompt = im.text.get("Dream", "")
        else:
            # When trying to retrieve metadata from images without a 'text'
            # payload, such as JPG images.
            md = "{}"
            dream_prompt = ""
    return {"sd-metadata": json.loads(md), "Dream": dream_prompt}
2023-03-03 06:02:00 +00:00
def write_metadata(img_path: str, meta: dict):
    """
    Re-save the image at img_path in place as a PNG whose "sd-metadata"
    text chunk holds the JSON serialization of meta.

    NOTE(review): only the "sd-metadata" chunk is passed to save(); confirm
    whether an existing "Dream" chunk is expected to survive this call.
    """
    # The context manager releases the source file handle once the save
    # has finished, instead of leaving it open.
    with Image.open(img_path) as im:
        info = PngImagePlugin.PngInfo()
        info.add_text("sd-metadata", json.dumps(meta))
        im.save(img_path, "PNG", pnginfo=info)
class PromptFormatter:
    """Render a prompt plus its generation switches as one string.

    Holds a ``t2i`` object and an ``opt`` object; for the basic switches a
    falsy value on ``opt`` falls back to the attribute of the same name on
    ``t2i``.
    """

    def __init__(self, t2i, opt):
        self.t2i = t2i
        self.opt = opt

    # note: the t2i object should provide all these values.
    # there should be no need to or against opt values
    def normalize_prompt(self):
        """Normalize the prompt and switches"""
        opt, t2i = self.opt, self.t2i

        # Switches that are always present, in their fixed order.
        parts = [
            f'"{opt.prompt}"',
            f"-s{opt.steps or t2i.steps}",
            f"-W{opt.width or t2i.width}",
            f"-H{opt.height or t2i.height}",
            f"-C{opt.cfg_scale or t2i.cfg_scale}",
            f"-A{opt.sampler_name or t2i.sampler_name}",
        ]
        # to do: put model name into the t2i object
        # parts.append(f'--model{t2i.model_name}')

        # Optional switches, appended only when the matching option is set.
        if opt.seamless or t2i.seamless:
            parts.append("--seamless")
        if opt.init_img:
            parts.append(f"-I{opt.init_img}")
        if opt.fit:
            parts.append("--fit")
        if opt.strength and opt.init_img is not None:
            parts.append(f"-f{opt.strength or t2i.strength}")
        if opt.gfpgan_strength:
            parts.append(f"-G{opt.gfpgan_strength}")
        if opt.upscale:
            parts.append("-U " + " ".join(map(str, opt.upscale)))
        if opt.variation_amount > 0:
            parts.append(f"-v{opt.variation_amount}")
        if opt.with_variations:
            pairs = [f"{seed}:{weight}" for seed, weight in opt.with_variations]
            parts.append("-V" + ",".join(pairs))

        return " ".join(parts)