make metadata retrieval more robust to changes in storage format

- args.py will now attempt to return a metadata-containing Args
  object using the following methods:

1. By looking for the 'sd-metadata' tag in the PNG info
2. By looking from the 'Dream' tag
3. As a last resort, fetch the seed from the filename and assume
   defaults for all other options.
This commit is contained in:
Lincoln Stein 2022-09-26 04:18:23 -04:00
parent b512d198f0
commit 14616f4178

View File

@ -86,6 +86,7 @@ import shlex
import json
import hashlib
import os
import re
import copy
import base64
import ldm.dream.pngwriter
@ -732,19 +733,24 @@ def metadata_dumps(opt,
return metadata
def metadata_from_png(png_file_path):
def metadata_from_png(png_file_path) -> Args:
'''
Given the path to a PNG file created by dream.py, retrieves
an Args object containing the image metadata
an Args object containing the image metadata. Note that this
returns a single Args object, not multiple.
'''
meta = ldm.dream.pngwriter.retrieve_metadata(png_file_path)
opts = metadata_loads(meta)
return opts[0]
if 'sd-metadata' in meta and len(meta['sd-metadata'])>0 :
return metadata_loads(meta)[0]
else:
return legacy_metadata_load(meta,png_file_path)
def metadata_loads(metadata):
def metadata_loads(metadata) ->list:
'''
Takes the dictionary corresponding to RFC266 (https://github.com/lstein/stable-diffusion/issues/266)
and returns a series of opt objects for each of the images described in the dictionary.
and returns a series of opt objects for each of the images described in the dictionary. Note that this
returns a list, and not a single object. See metadata_from_png() for a more convenient function for
files that contain a single image.
'''
results = []
try:
@ -797,3 +803,18 @@ def sha256(path):
sha.update(data)
return sha.hexdigest()
def legacy_metadata_load(meta,pathname) -> Args:
if 'Dream' in meta and len(meta['Dream']) > 0:
dream_prompt = meta['Dream']
opt = Args()
opt.parse_cmd(dream_prompt)
return opt
else: # if nothing else, we can get the seed
match = re.search('\d+\.(\d+)',pathname)
if match:
seed = match.groups()[0]
opt = Args()
opt.seed = seed
return opt
return None