diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 0b6cfda4cc..6ff164ebfa 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -2,7 +2,10 @@ The Args class parses both the command line (shell) arguments, as well as the command string passed at the dream> prompt. It serves as the definitive repository -of all the arguments used by Generate and their default values. +of all the arguments used by Generate and their default values, and implements the +preliminary metadata standards discussed here: + +https://github.com/lstein/stable-diffusion/issues/266 To use: opt = Args() @@ -52,10 +55,32 @@ you wish to apply logic as to which one to use. For example: To add new attributes, edit the _create_arg_parser() and _create_dream_cmd_parser() methods. -We also export the function build_metadata +**Generating and retrieving sd-metadata** + +To generate a dict representing RFC266 metadata: + + metadata = metadata_dumps(opt,) + +This will generate an RFC266 dictionary that can then be turned into a JSON +and written to the PNG file. The optional seeds, weights, model_hash and +postprocesser arguments are not available to the opt object and so must be +provided externally. See how dream.py does it. + +Note that this function was originally called format_metadata() and a wrapper +is provided that issues a deprecation notice. + +To retrieve a (series of) opt objects corresponding to the metadata, do this: + + opt_list = metadata_loads(metadata) + +The metadata should be pulled out of the PNG image. pngwriter has a method +retrieve_metadata that will do this. + + """ import argparse +from argparse import Namespace import shlex import json import hashlib @@ -540,17 +565,20 @@ class Args(object): ) return parser -# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266 -# it does not write all the required top-level metadata, writes too much image -# data, and doesn't support grids yet. But you gotta start somewhere, no? -def format_metadata(opt, - seeds=[], - weights=None, - model_hash=None, - postprocessing=None): +def format_metadata(**kwargs): + print(f'format_metadata() is deprecated. Please use metadata_dumps()') + return metadata_dumps(kwargs) + +def metadata_dumps(opt, + seeds=[], + model_hash=None, + postprocessing=None): ''' - Given an Args object, returns a partial implementation of - the stable diffusion metadata standard + Given an Args object, returns a dict containing the keys and + structure of the proposed stable diffusion metadata standard + https://github.com/lstein/stable-diffusion/discussions/392 + This is intended to be turned into JSON and stored in the + "sd ''' # add some RFC266 fields that are generated internally, and not as # user args @@ -611,6 +639,27 @@ def format_metadata(opt, 'images' : images, } +def metadata_loads(metadata): + ''' + Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266) + and returns a series of opt objects for each of the images described in the dictionary. + ''' + results = [] + try: + images = metadata['sd-metadata']['images'] + for image in images: + # repack the prompt and variations + image['prompt'] = ','.join([':'.join([x['prompt'], str(x['weight'])]) for x in image['prompt']]) + image['variations'] = ','.join([':'.join([str(x['seed']),str(x['weight'])]) for x in image['variations']]) + opt = Args() + opt._cmd_switches = Namespace(**image) + results.append(opt) + except KeyError as e: + import sys, traceback + print('>> badly-formatted metadata',file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return results + # image can either be a file path on disk or a base64-encoded # representation of the file's contents def calculate_init_img_hash(image_string): diff --git a/ldm/dream/server.py b/ldm/dream/server.py index 5dd8f38c2f..9e37c070d1 100644 --- a/ldm/dream/server.py +++ b/ldm/dream/server.py @@ -4,7 +4,7 @@ import copy import base64 import mimetypes import os -from ldm.dream.args import Args, format_metadata +from ldm.dream.args import Args, metadata_dumps from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from ldm.dream.pngwriter import PngWriter from threading import Event @@ -175,10 +175,9 @@ class DreamServer(BaseHTTPRequestHandler): path = pngwriter.save_image_and_prompt_to_png( image, dream_prompt = formatted_prompt, - metadata = format_metadata(iter_opt, - seeds = [seed], - weights = self.model.weights, - model_hash = self.model.model_hash + metadata = metadata_dumps(iter_opt, + seeds = [seed], + model_hash = self.model.model_hash ), name = name, ) diff --git a/scripts/dream.py b/scripts/dream.py index 20a2d87e65..289a89e8ad 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -8,7 +8,7 @@ import copy import warnings import time import ldm.dream.readline -from ldm.dream.args import Args, format_metadata +from ldm.dream.args import Args, metadata_dumps from ldm.dream.pngwriter import PngWriter from ldm.dream.server import DreamServer, ThreadingDreamServer from ldm.dream.image_util import make_grid @@ -245,10 +245,9 @@ def main_loop(gen, opt, infile): path = file_writer.save_image_and_prompt_to_png( image = image, dream_prompt = formatted_dream_prompt, - metadata = format_metadata( + metadata = metadata_dumps( opt, seeds = [seed], - weights = gen.weights, model_hash = gen.model_hash, ), name = filename, @@ -272,7 +271,7 @@ def main_loop(gen, opt, infile): filename = f'{prefix}.{first_seed}.png' formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images)) formatted_dream_prompt += f' # {grid_seeds}' - metadata = format_metadata( + metadata = metadata.dumps( opt, seeds = grid_seeds, weights = gen.weights,