add sd-metadata metadata_loads() and metadata_dumps() functions

This commit is contained in:
Lincoln Stein 2022-09-17 13:28:37 -04:00
parent 2faa116238
commit 239f41f3e0
3 changed files with 68 additions and 21 deletions

View File

@ -2,7 +2,10 @@
The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository
of all the arguments used by Generate and their default values.
of all the arguments used by Generate and their default values, and implements the
preliminary metadata standards discussed here:
https://github.com/lstein/stable-diffusion/issues/266
To use:
opt = Args()
@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example:
To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods.
We also export the function build_metadata
**Generating and retrieving sd-metadata**
To generate a dict representing RFC266 metadata:
metadata = metadata_dumps(opt,<seeds,model_hash,postprocesser>)
This will generate an RFC266 dictionary that can then be turned into a JSON
and written to the PNG file. The optional seeds, weights, model_hash and
postprocesser arguments are not available to the opt object and so must be
provided externally. See how dream.py does it.
Note that this function was originally called format_metadata() and a wrapper
is provided that issues a deprecation notice.
To retrieve a (series of) opt objects corresponding to the metadata, do this:
opt_list = metadata_loads(metadata)
The metadata should be pulled out of the PNG image. pngwriter has a method
retrieve_metadata that will do this.
"""
import argparse
from argparse import Namespace
import shlex
import json
import hashlib
@ -540,17 +565,20 @@ class Args(object):
)
return parser
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
# it does not write all the required top-level metadata, writes too much image
# data, and doesn't support grids yet. But you gotta start somewhere, no?
def format_metadata(opt,
seeds=[],
weights=None,
model_hash=None,
postprocessing=None):
def format_metadata(**kwargs):
print(f'format_metadata() is deprecated. Please use metadata_dumps()')
return metadata_dumps(kwargs)
def metadata_dumps(opt,
seeds=[],
model_hash=None,
postprocessing=None):
'''
Given an Args object, returns a partial implementation of
the stable diffusion metadata standard
Given an Args object, returns a dict containing the keys and
structure of the proposed stable diffusion metadata standard
https://github.com/lstein/stable-diffusion/discussions/392
This is intended to be turned into JSON and stored in the
"sd
'''
# add some RFC266 fields that are generated internally, and not as
# user args
@ -611,6 +639,27 @@ def format_metadata(opt,
'images' : images,
}
def metadata_loads(metadata):
'''
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
and returns a series of opt objects for each of the images described in the dictionary.
'''
results = []
try:
images = metadata['sd-metadata']['images']
for image in images:
# repack the prompt and variations
image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']])
image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']])
opt = Args()
opt._cmd_switches = Namespace(**image)
results.append(opt)
except KeyError as e:
import sys, traceback
print('>> badly-formatted metadata',file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return results
# image can either be a file path on disk or a base64-encoded
# representation of the file's contents
def calculate_init_img_hash(image_string):

View File

@ -4,7 +4,7 @@ import copy
import base64
import mimetypes
import os
from ldm.dream.args import Args, format_metadata
from ldm.dream.args import Args, metadata_dumps
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter
from threading import Event
@ -175,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler):
path = pngwriter.save_image_and_prompt_to_png(
image,
dream_prompt = formatted_prompt,
metadata = format_metadata(iter_opt,
seeds = [seed],
weights = self.model.weights,
model_hash = self.model.model_hash
metadata = metadata_dumps(iter_opt,
seeds = [seed],
model_hash = self.model.model_hash
),
name = name,
)

View File

@ -8,7 +8,7 @@ import copy
import warnings
import time
import ldm.dream.readline
from ldm.dream.args import Args, format_metadata
from ldm.dream.args import Args, metadata_dumps
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
@ -245,10 +245,9 @@ def main_loop(gen, opt, infile):
path = file_writer.save_image_and_prompt_to_png(
image = image,
dream_prompt = formatted_dream_prompt,
metadata = format_metadata(
metadata = metadata_dumps(
opt,
seeds = [seed],
weights = gen.weights,
model_hash = gen.model_hash,
),
name = filename,
@ -272,7 +271,7 @@ def main_loop(gen, opt, infile):
filename = f'{prefix}.{first_seed}.png'
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
formatted_dream_prompt += f' # {grid_seeds}'
metadata = format_metadata(
metadata = metadata.dumps(
opt,
seeds = grid_seeds,
weights = gen.weights,