add sd-metadata metadata_loads() and metadata_dumps() functions

This commit is contained in:
Lincoln Stein 2022-09-17 13:28:37 -04:00
parent 2faa116238
commit 239f41f3e0
3 changed files with 68 additions and 21 deletions

View File

@ -2,7 +2,10 @@
The Args class parses both the command line (shell) arguments, as well as the The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository command string passed at the dream> prompt. It serves as the definitive repository
of all the arguments used by Generate and their default values. of all the arguments used by Generate and their default values, and implements the
preliminary metadata standards discussed here:
https://github.com/lstein/stable-diffusion/issues/266
To use: To use:
opt = Args() opt = Args()
@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
To add new attributes, edit the _create_arg_parser() and To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods. _create_dream_cmd_parser() methods.
We also export the function build_metadata **Generating and retrieving sd-metadata**
To generate a dict representing RFC266 metadata:
metadata = metadata_dumps(opt,<seeds,model_hash,postprocessing>)
This will generate an RFC266 dictionary that can then be turned into JSON
and written to the PNG file. The optional seeds, model_hash and
postprocessing arguments are not available to the opt object and so must be
provided externally. See how dream.py does it.
Note that this function was originally called format_metadata() and a wrapper
is provided that issues a deprecation notice.
To retrieve a (series of) opt objects corresponding to the metadata, do this:
opt_list = metadata_loads(metadata)
The metadata should be pulled out of the PNG image. pngwriter has a method
retrieve_metadata that will do this.
""" """
import argparse import argparse
from argparse import Namespace
import shlex import shlex
import json import json
import hashlib import hashlib
@ -540,17 +565,20 @@ class Args(object):
) )
return parser return parser
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266 def format_metadata(**kwargs):
# it does not write all the required top-level metadata, writes too much image print(f'format_metadata() is deprecated. Please use metadata_dumps()')
# data, and doesn't support grids yet. But you gotta start somewhere, no? return metadata_dumps(kwargs)
def format_metadata(opt,
seeds=[], def metadata_dumps(opt,
weights=None, seeds=[],
model_hash=None, model_hash=None,
postprocessing=None): postprocessing=None):
''' '''
Given an Args object, returns a partial implementation of Given an Args object, returns a dict containing the keys and
the stable diffusion metadata standard structure of the proposed stable diffusion metadata standard
https://github.com/lstein/stable-diffusion/discussions/392
This is intended to be turned into JSON and stored in the
"sd-metadata" field of the PNG file's metadata.
''' '''
# add some RFC266 fields that are generated internally, and not as # add some RFC266 fields that are generated internally, and not as
# user args # user args
@ -611,6 +639,27 @@ def format_metadata(opt,
'images' : images, 'images' : images,
} }
def metadata_loads(metadata):
    '''
    Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
    and returns a list of opt (Args) objects, one for each of the images
    described in the dictionary.

    metadata -- dict with an 'sd-metadata' entry whose 'images' value is a
                list of per-image dicts. Each image dict must provide a
                'prompt' list of {'prompt', 'weight'} dicts and a
                'variations' list of {'seed', 'weight'} dicts.

    Each returned Args object has its _cmd_switches set to a Namespace built
    from the image's fields, with 'prompt' and 'variations' repacked into
    comma-separated "value:weight" strings.

    If a required key is missing, a message and traceback are printed to
    stderr and the objects built so far are returned (possibly an empty list).
    The caller's metadata dictionary is never modified, so the same
    dictionary can safely be loaded more than once.
    '''
    results = []
    try:
        for image in metadata['sd-metadata']['images']:
            # Work on a shallow copy: overwriting 'prompt'/'variations' in the
            # caller's dict would corrupt it for any later use (a second call
            # would try to index into the already-repacked strings).
            switches = dict(image)
            # repack the prompt and variations
            switches['prompt'] = ','.join(
                [':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']]
            )
            switches['variations'] = ','.join(
                [':'.join([str(x['seed']), str(x['weight'])]) for x in image['variations']]
            )
            opt = Args()
            opt._cmd_switches = Namespace(**switches)
            results.append(opt)
    except KeyError:
        import sys, traceback
        print('>> badly-formatted metadata',file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
    return results
# image can either be a file path on disk or a base64-encoded # image can either be a file path on disk or a base64-encoded
# representation of the file's contents # representation of the file's contents
def calculate_init_img_hash(image_string): def calculate_init_img_hash(image_string):

View File

@ -4,7 +4,7 @@ import copy
import base64 import base64
import mimetypes import mimetypes
import os import os
from ldm.dream.args import Args, format_metadata from ldm.dream.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from threading import Event from threading import Event
@ -175,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler):
path = pngwriter.save_image_and_prompt_to_png( path = pngwriter.save_image_and_prompt_to_png(
image, image,
dream_prompt = formatted_prompt, dream_prompt = formatted_prompt,
metadata = format_metadata(iter_opt, metadata = metadata_dumps(iter_opt,
seeds = [seed], seeds = [seed],
weights = self.model.weights, model_hash = self.model.model_hash
model_hash = self.model.model_hash
), ),
name = name, name = name,
) )

View File

@ -8,7 +8,7 @@ import copy
import warnings import warnings
import time import time
import ldm.dream.readline import ldm.dream.readline
from ldm.dream.args import Args, format_metadata from ldm.dream.args import Args, metadata_dumps
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid from ldm.dream.image_util import make_grid
@ -245,10 +245,9 @@ def main_loop(gen, opt, infile):
path = file_writer.save_image_and_prompt_to_png( path = file_writer.save_image_and_prompt_to_png(
image = image, image = image,
dream_prompt = formatted_dream_prompt, dream_prompt = formatted_dream_prompt,
metadata = format_metadata( metadata = metadata_dumps(
opt, opt,
seeds = [seed], seeds = [seed],
weights = gen.weights,
model_hash = gen.model_hash, model_hash = gen.model_hash,
), ),
name = filename, name = filename,
@ -272,7 +271,7 @@ def main_loop(gen, opt, infile):
filename = f'{prefix}.{first_seed}.png' filename = f'{prefix}.{first_seed}.png'
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images)) formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
formatted_dream_prompt += f' # {grid_seeds}' formatted_dream_prompt += f' # {grid_seeds}'
metadata = format_metadata( metadata = metadata.dumps(
opt, opt,
seeds = grid_seeds, seeds = grid_seeds,
weights = gen.weights, weights = gen.weights,