implementation of RFC #266 (#587)

* Feature complete for #266 with exception of several small deviations:
1. initial image and model weight hashes use full sha256 hash rather than first 8 digits
2. Initialization parameters for post-processing steps not provided
3. Uses top-level "images" tags for both a single image and a grid of images. This change was suggested in a comment.

* Added scripts/sd_metadata.py to retrieve and print metadata from PNG files
* New ldm.dream.args.Args class is a namespace like object which holds all defaults and can be modified during exection to hold current settings.
* Modified dream.py and server.py to accommodate Args class.
This commit is contained in:
Lincoln Stein 2022-09-16 13:09:04 -04:00 committed by GitHub
parent 45af30f3a4
commit 403d02d94f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 768 additions and 433 deletions

619
ldm/dream/args.py Normal file
View File

@ -0,0 +1,619 @@
"""Helper class for dealing with image generation arguments.
The Args class parses both the command line (shell) arguments, as well as the
command string passed at the dream> prompt. It serves as the definitive repository
of all the arguments used by Generate and their default values.
To use:
opt = Args()
# Read in the command line options:
# this returns a namespace object like the underlying argparse library)
# You do not have to use the return value, but you can check it against None
# to detect illegal arguments on the command line.
args = opt.parse_args()
if not args:
print('oops')
sys.exit(-1)
# read in a command passed to the dream> prompt:
opts = opt.parse_cmd('do androids dream of electric sheep? -H256 -W1024 -n4')
# The Args object acts like a namespace object
print(opt.model)
You can set attributes in the usual way, use vars(), etc.:
opt.model = 'something-else'
do_something(**vars(a))
It is helpful in saving metadata:
# To get a json representation of all the values, allowing
# you to override any values dynamically
j = opt.json(seed=42)
# To get the prompt string with the switches, allowing you
# to override any values dynamically
j = opt.dream_prompt_str(seed=42)
If you want to access the namespace objects from the shell args or the
parsed command directly, you may use the values returned from the
original calls to parse_args() and parse_cmd(), or get them later
using the _arg_switches and _cmd_switches attributes. This can be
useful if both the args and the command contain the same attribute and
you wish to apply logic as to which one to use. For example:
a = Args()
args = a.parse_args()
opts = a.parse_cmd(string)
do_grid = args.grid or opts.grid
To add new attributes, edit the _create_arg_parser() and
_create_dream_cmd_parser() methods.
We also export the function build_metadata
"""
import argparse
import shlex
import json
import hashlib
import os
import copy
from ldm.dream.conditioning import split_weighted_subprompts
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
]
# is there a way to pick this up during git commits?
APP_ID = 'lstein/stable-diffusion'
APP_VERSION = 'v1.15'
class Args(object):
def __init__(self,arg_parser=None,cmd_parser=None):
'''
Initialize new Args class. It takes two optional arguments, an argparse
parser for switches given on the shell command line, and an argparse
parser for switches given on the dream> CLI line. If one or both are
missing, it creates appropriate parsers internally.
'''
self._arg_parser = arg_parser or self._create_arg_parser()
self._cmd_parser = cmd_parser or self._create_dream_cmd_parser()
self._arg_switches = self.parse_cmd('') # fill in defaults
self._cmd_switches = self.parse_cmd('') # fill in defaults
def parse_args(self):
'''Parse the shell switches and store.'''
try:
self._arg_switches = self._arg_parser.parse_args()
return self._arg_switches
except:
return None
def parse_cmd(self,cmd_string):
'''Parse a dream>-style command string '''
command = cmd_string.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError:
print(traceback.format_exc(), file=sys.stderr)
return
switches = ['']
switches_started = False
for element in elements:
if element[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(element)
else:
switches[0] += element
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
try:
self._cmd_switches = self._cmd_parser.parse_args(switches)
return self._cmd_switches
except:
return None
def json(self,**kwargs):
return json.dumps(self.to_dict(**kwargs))
def to_dict(self,**kwargs):
a = vars(self)
a.update(kwargs)
return a
# Isn't there a more automated way of doing this?
# Ideally we get the switch strings out of the argparse objects,
# but I don't see a documented API for this.
def dream_prompt_str(self,**kwargs):
"""Normalized dream_prompt."""
a = vars(self)
a.update(kwargs)
switches = list()
switches.append(f'"{a["prompt"]}')
switches.append(f'-s {a["steps"]}')
switches.append(f'-W {a["width"]}')
switches.append(f'-H {a["height"]}')
switches.append(f'-C {a["cfg_scale"]}')
switches.append(f'-A {a["sampler_name"]}')
switches.append(f'-S {a["seed"]}')
if a['grid']:
switches.append('--grid')
if a['iterations'] and a['iterations']>0:
switches.append(f'-n {a["iterations"]}')
if a['seamless']:
switches.append('--seamless')
if a['init_img'] and len(a['init_img'])>0:
switches.append(f'-I {a["init_img"]}')
if a['fit']:
switches.append(f'--fit')
if a['strength'] and a['strength']>0:
switches.append(f'-f {a["strength"]}')
if a['gfpgan_strength']:
switches.append(f'-G {a["gfpgan_strength"]}')
if a['upscale']:
switches.append(f'-U {" ".join([str(u) for u in a["upscale"]])}')
if a['embiggen']:
switches.append(f'--embiggen {" ".join([str(u) for u in a["embiggen"]])}')
if a['embiggen_tiles']:
switches.append(f'--embiggen_tiles {" ".join([str(u) for u in a["embiggen_tiles"]])}')
if a['variation_amount'] > 0:
switches.append(f'-v {a["variation_amount"]}')
if a['with_variations']:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in (a["with_variations"]))
switches.append(f'-V {formatted_variations}')
return ' '.join(switches)
def __getattribute__(self,name):
'''
Returns union of command-line arguments and dream_prompt arguments,
with the latter superseding the former.
'''
cmd_switches = None
arg_switches = None
try:
cmd_switches = object.__getattribute__(self,'_cmd_switches')
arg_switches = object.__getattribute__(self,'_arg_switches')
except AttributeError:
pass
if cmd_switches and arg_switches and name=='__dict__':
a = arg_switches.__dict__
a.update(cmd_switches.__dict__)
return a
try:
return object.__getattribute__(self,name)
except AttributeError:
pass
if not hasattr(cmd_switches,name) and not hasattr(arg_switches,name):
raise AttributeError
value_arg,value_cmd = (None,None)
try:
value_cmd = getattr(cmd_switches,name)
except AttributeError:
pass
try:
value_arg = getattr(arg_switches,name)
except AttributeError:
pass
# here is where we can pick and choose which to use
# default behavior is to choose the dream_command value over
# the arg value. For example, the --grid and --individual options are a little
# funny because of their push/pull relationship. This is how to handle it.
if name=='grid':
return value_arg or value_cmd # arg supersedes cmd
if name=='individual':
return value_cmd or value_arg # cmd supersedes arg
if value_cmd is not None:
return value_cmd
else:
return value_arg
def __setattr__(self,name,value):
if name.startswith('_'):
object.__setattr__(self,name,value)
else:
self._cmd_switches.__dict__[name] = value
def _create_arg_parser(self):
'''
This defines all the arguments used on the command line when you launch
the CLI or web backend.
'''
parser = argparse.ArgumentParser(
description=
"""
Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
prompt the command prompt.
""",
)
model_group = parser.add_argument_group('Model selection')
file_group = parser.add_argument_group('Input/output')
web_server_group = parser.add_argument_group('Web server')
render_group = parser.add_argument_group('Rendering')
postprocessing_group = parser.add_argument_group('Postprocessing')
deprecated_group = parser.add_argument_group('Deprecated options')
deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument(
'--conf',
'-c',
'-conf',
dest='conf',
default='./configs/models.yaml',
help='Path to configuration file for alternate models.',
)
model_group.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
)
model_group.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Use more memory-intensive full precision math for calculations',
)
file_group.add_argument(
'--from_file',
dest='infile',
type=str,
help='If specified, load prompts from this file',
)
file_group.add_argument(
'--outdir',
'-o',
type=str,
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
default='outputs/img-samples',
)
file_group.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
render_group.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
render_group.add_argument(
'--grid',
'-g',
action='store_true',
help='generate a grid'
)
render_group.add_argument(
'--embedding_path',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
# GFPGAN related args
postprocessing_group.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
)
postprocessing_group.add_argument(
'--gfpgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
)
postprocessing_group.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
)
postprocessing_group.add_argument(
'--gfpgan_dir',
type=str,
default='./src/gfpgan',
help='Indicates the directory containing the GFPGAN code.',
)
web_server_group.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
)
web_server_group.add_argument(
'--host',
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
)
web_server_group.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
return parser
# This creates the parser that processes commands on the dream> command line
def _create_dream_cmd_parser(self):
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
render_group = parser.add_argument_group('General rendering')
img2img_group = parser.add_argument_group('Image-to-image and inpainting')
variation_group = parser.add_argument_group('Creating and combining variations')
postprocessing_group = parser.add_argument_group('Post-processing')
special_effects_group = parser.add_argument_group('Special effects')
render_group.add_argument('prompt')
render_group.add_argument(
'-s',
'--steps',
type=int,
default=50,
help='Number of steps'
)
render_group.add_argument(
'-S',
'--seed',
type=int,
default=None,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
render_group.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
render_group.add_argument(
'-W',
'--width',
type=int,
help='Image width, multiple of 64',
default=512
)
render_group.add_argument(
'-H',
'--height',
type=int,
help='Image height, multiple of 64',
default=512,
)
render_group.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
render_group.add_argument(
'--grid',
'-g',
action='store_true',
help='generate a grid'
)
render_group.add_argument(
'--individual',
'-i',
action='store_true',
help='override command-line --grid setting and generate individual images'
)
render_group.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
render_group.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
default='k_lms',
)
render_group.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
render_group.add_argument(
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='Directory to save generated images and a log of prompts and seeds',
)
img2img_group.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
img2img_group.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
img2img_group.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
img2img_group.add_argument(
'-f',
'--strength',
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
default=0.75,
)
postprocessing_group.add_argument(
'-G',
'--gfpgan_strength',
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
default=0,
)
postprocessing_group.add_argument(
'-U',
'--upscale',
nargs='+',
type=float,
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
default=None,
)
postprocessing_group.add_argument(
'--save_original',
'-save_orig',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
postprocessing_group.add_argument(
'--embiggen',
'-embiggen',
nargs='+',
type=float,
help='Embiggen tiled img2img for higher resolution and detail without extra VRAM usage. Takes scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels. ESRGAN strength defaults to 0.75, and overlap defaults to 0.25 . ESRGAN is used to upscale the init prior to cutting it into tiles/pieces to run through img2img and then stitch back togeather.',
default=None,
)
postprocessing_group.add_argument(
'--embiggen_tiles',
'-embiggen_tiles',
nargs='+',
type=int,
help='If while doing Embiggen we are altering only parts of the image, takes a list of tiles by number to process and replace onto the image e.g. `1 3 5`, useful for redoing problematic spots from a prior Embiggen run',
default=None,
)
special_effects_group.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
variation_group.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
variation_group.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser
# very partial implementation of https://github.com/lstein/stable-diffusion/issues/266
# it does not write all the required top-level metadata, writes too much image
# data, and doesn't support grids yet. But you gotta start somewhere, no?
def format_metadata(opt,
seeds=[],
weights=None,
model_hash=None,
postprocessing=None):
'''
Given an Args object, returns a partial implementation of
the stable diffusion metadata standard
'''
# add some RFC266 fields that are generated internally, and not as
# user args
image_dict = opt.to_dict(
postprocessing=postprocessing
)
# TODO: This is just a hack until postprocessing pipeline work completed
image_dict['postprocessing'] = []
if image_dict['gfpgan_strength'] and image_dict['gfpgan_strength'] > 0:
image_dict['postprocessing'].append('GFPGAN (not RFC compliant)')
if image_dict['upscale'] and image_dict['upscale'][0] > 0:
image_dict['postprocessing'].append('ESRGAN (not RFC compliant)')
# remove any image keys not mentioned in RFC #266
rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps',
'cfg_scale','step_number','width','height','extra','strength']
rfc_dict ={}
for item in image_dict.items():
key,value = item
if key in rfc266_img_fields:
rfc_dict[key] = value
# semantic drift
rfc_dict['sampler'] = image_dict.get('sampler_name',None)
# display weighted subprompts (liable to change)
if opt.prompt:
subprompts = split_weighted_subprompts(opt.prompt)
subprompts = [{'prompt':x[0],'weight':x[1]} for x in subprompts]
rfc_dict['prompt'] = subprompts
# variations
if opt.with_variations:
variations = [{'seed':x[0],'weight':x[1]} for x in opt.with_variations]
rfc_dict['variations'] = variations
if opt.init_img:
rfc_dict['type'] = 'img2img'
rfc_dict['strength_steps'] = rfc_dict.pop('strength')
rfc_dict['orig_hash'] = sha256(image_dict['init_img'])
rfc_dict['sampler'] = 'ddim' # FIX ME WHEN IMG2IMG SUPPORTS ALL SAMPLERS
else:
rfc_dict['type'] = 'txt2img'
images = []
for seed in seeds:
rfc_dict['seed'] = seed
images.append(copy.copy(rfc_dict))
return {
'model' : 'stable diffusion',
'model_id' : opt.model,
'model_hash' : model_hash,
'app_id' : APP_ID,
'app_version' : APP_VERSION,
'images' : images,
}
# Bah. This should be moved somewhere else...
def sha256(path):
sha = hashlib.sha256()
with open(path,'rb') as f:
while True:
data = f.read(65536)
if not data:
break
sha.update(data)
return sha.hexdigest()

View File

@ -3,12 +3,13 @@ Two helper classes for dealing with PNG images and their path names.
PngWriter -- Converts Images generated by T2I into PNGs, finds
appropriate names for them, and writes prompt metadata
into the PNG.
PromptFormatter -- Utility for converting a Namespace of prompt parameters
back into a formatted prompt string with command-line switches.
Exports function retrieve_metadata(path)
"""
import os
import re
from PIL import PngImagePlugin
import json
from PIL import PngImagePlugin, Image
# -------------------image generation utils-----
@ -32,54 +33,31 @@ class PngWriter:
# saves image named _image_ to outdir/name, writing metadata from prompt
# returns full path of output
def save_image_and_prompt_to_png(self, image, prompt, name):
def save_image_and_prompt_to_png(self, image, dream_prompt, metadata, name):
path = os.path.join(self.outdir, name)
info = PngImagePlugin.PngInfo()
info.add_text('Dream', prompt)
info.add_text('Dream', dream_prompt)
info.add_text('sd-metadata', json.dumps(metadata))
image.save(path, 'PNG', pnginfo=info)
return path
def retrieve_metadata(self,img_basename):
'''
Given a PNG filename stored in outdir, returns the "sd-metadata"
metadata stored there, as a dict
'''
path = os.path.join(self.outdir,img_basename)
return retrieve_metadata(path)
class PromptFormatter:
def __init__(self, t2i, opt):
self.t2i = t2i
self.opt = opt
def retrieve_metadata(img_path):
'''
Given a path to a PNG image, returns the "sd-metadata"
metadata stored there, as a dict
'''
im = Image.open(img_path)
md = im.text.get('sd-metadata',{})
return json.loads(md)
# note: the t2i object should provide all these values.
# there should be no need to or against opt values
def normalize_prompt(self):
"""Normalize the prompt and switches"""
t2i = self.t2i
opt = self.opt
switches = list()
switches.append(f'"{opt.prompt}"')
switches.append(f'-s{opt.steps or t2i.steps}')
switches.append(f'-W{opt.width or t2i.width}')
switches.append(f'-H{opt.height or t2i.height}')
switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}')
switches.append(f'-A{opt.sampler_name or t2i.sampler_name}')
# to do: put model name into the t2i object
# switches.append(f'--model{t2i.model_name}')
if opt.seamless or t2i.seamless:
switches.append(f'--seamless')
if opt.init_img:
switches.append(f'-I{opt.init_img}')
if opt.fit:
switches.append(f'--fit')
if opt.strength and opt.init_img is not None:
switches.append(f'-f{opt.strength or t2i.strength}')
if opt.gfpgan_strength:
switches.append(f'-G{opt.gfpgan_strength}')
if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if hasattr(opt, 'embiggen') and opt.embiggen:
switches.append(f'-embiggen {" ".join([str(u) for u in opt.embiggen])}')
if hasattr(opt, 'embiggen_tiles') and opt.embiggen_tiles:
switches.append(f'-embiggen_tiles {" ".join([str(u) for u in opt.embiggen_tiles])}')
if opt.variation_amount > 0:
switches.append(f'-v{opt.variation_amount}')
if opt.with_variations:
formatted_variations = ','.join(f'{seed}:{weight}' for seed, weight in opt.with_variations)
switches.append(f'-V{formatted_variations}')
return ' '.join(switches)

View File

@ -1,14 +1,17 @@
import argparse
import json
import copy
import base64
import mimetypes
import os
from ldm.dream.args import Args, format_metadata
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.pngwriter import PngWriter
from threading import Event
def build_opt(post_data, seed, gfpgan_model_exists):
opt = argparse.Namespace()
opt = Args()
opt.parse_args() # initialize defaults
setattr(opt, 'prompt', post_data['prompt'])
setattr(opt, 'init_img', post_data['initimg'])
setattr(opt, 'strength', float(post_data['strength']))
@ -40,7 +43,7 @@ def build_opt(post_data, seed, gfpgan_model_exists):
for part in post_data['with_variations'].split(','):
seed_and_weight = part.split(':')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
print(f'could not parse WITH_variation part "{part}"')
broken = True
break
try:
@ -158,10 +161,10 @@ class DreamServer(BaseHTTPRequestHandler):
# the images are first generated, and then again when after upscaling
# is complete. The upscaling replaces the original file, so the second
# entry should not be inserted into the image list.
# LS: This repeats code in dream.py
def image_done(image, seed, upscaled=False):
name = f'{prefix}.{seed}.png'
iter_opt = argparse.Namespace(**vars(opt)) # copy
print(f'iter_opt = {iter_opt}')
iter_opt = copy.copy(opt)
if opt.variation_amount > 0:
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
@ -169,10 +172,17 @@ class DreamServer(BaseHTTPRequestHandler):
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
elif opt.with_variations is None:
iter_opt.seed = seed
normalized_prompt = PromptFormatter(self.model, iter_opt).normalize_prompt()
path = pngwriter.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{iter_opt.seed}', name)
formatted_prompt = opt.dream_prompt_str(seed=seed)
path = pngwriter.save_image_and_prompt_to_png(
image,
dream_prompt = formatted_prompt,
metadata = format_metadata(iter_opt,
seeds = [seed],
weights = self.model.weights,
model_hash = self.model.model_hash
),
name = name,
)
if int(config['seed']) == -1:
config['seed'] = seed

View File

@ -13,6 +13,8 @@ import re
import sys
import traceback
import transformers
import io
import hashlib
from omegaconf import OmegaConf
from PIL import Image, ImageOps
@ -567,7 +569,11 @@ class Generate:
# this does the work
c = OmegaConf.load(config)
pl_sd = torch.load(weights, map_location='cpu')
with open(weights,'rb') as f:
weight_bytes = f.read()
self.model_hash = self._cached_sha256(weights,weight_bytes)
pl_sd = torch.load(io.BytesIO(weight_bytes), map_location='cpu')
del weight_bytes
sd = pl_sd['state_dict']
model = instantiate_from_config(c.model)
m, u = model.load_state_dict(sd, strict=False)
@ -728,3 +734,24 @@ class Generate:
def _has_cuda(self):
return self.device.type == 'cuda'
def _cached_sha256(self,path,data):
dirname = os.path.dirname(path)
basename = os.path.basename(path)
base, _ = os.path.splitext(basename)
hashpath = os.path.join(dirname,base+'.sha256')
if os.path.exists(hashpath) and os.path.getmtime(path) <= os.path.getmtime(hashpath):
with open(hashpath) as f:
hash = f.read()
return hash
print(f'>> Calculating sha256 hash of weights file')
tic = time.time()
sha = hashlib.sha256()
sha.update(data)
hash = sha.hexdigest()
toc = time.time()
print(f'>> sha256 = {hash}','(%4.2fs)' % (toc - tic))
with open(hashpath,'w') as f:
f.write(hash)
return hash

View File

@ -5,10 +5,11 @@ import sys
import numpy as np
from PIL import Image
from scripts.dream import create_argv_parser
#from scripts.dream import create_argv_parser
from ldm.dream.args import Args
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
opt = Args()
opt.parse_args()
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
gfpgan_model_exists = os.path.isfile(model_path)

426
scripts/dream.py Executable file → Normal file
View File

@ -1,8 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import argparse
import shlex
import os
import re
import sys
@ -10,7 +8,8 @@ import copy
import warnings
import time
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.args import Args, format_metadata
from ldm.dream.pngwriter import PngWriter
from ldm.dream.server import DreamServer, ThreadingDreamServer
from ldm.dream.image_util import make_grid
from omegaconf import OmegaConf
@ -22,14 +21,16 @@ output_cntr = 0
def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
opt = arg_parser.parse_args()
opt = Args()
args = opt.parse_args()
if not args:
sys.exit(-1)
if opt.laion400m:
if args.laion400m:
print('--laion400m flag has been deprecated. Please use --model laion400m instead.')
sys.exit(-1)
if opt.weights != 'model':
print('--weights argument has been deprecated. Please configure ./configs/models.yaml, and call it using --model instead.')
if args.weights:
print('--weights argument has been deprecated. Please edit ./configs/models.yaml, and select the weights using --model instead.')
sys.exit(-1)
print('* Initializing, be patient...\n')
@ -47,7 +48,7 @@ def main():
# the user input loop
try:
gen = Generate(
conf = opt.config,
conf = opt.conf,
model = opt.model,
sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path,
@ -91,11 +92,10 @@ def main():
dream_server_loop(gen, opt.host, opt.port, opt.outdir)
sys.exit(0)
cmd_parser = create_cmd_parser()
main_loop(gen, opt.outdir, opt.prompt_as_dir, cmd_parser, infile)
main_loop(gen, opt, infile)
# TODO: main_loop() has gotten busy. Needs to be refactored.
def main_loop(gen, outdir, prompt_as_dir, parser, infile):
def main_loop(gen, opt, infile):
"""prompt/read/execute loop"""
done = False
path_filter = re.compile(r'[<>:"/\\|?*]')
@ -103,8 +103,8 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
# os.pathconf is not available on Windows
if hasattr(os, 'pathconf'):
path_max = os.pathconf(outdir, 'PC_PATH_MAX')
name_max = os.pathconf(outdir, 'PC_NAME_MAX')
path_max = os.pathconf(opt.outdir, 'PC_PATH_MAX')
name_max = os.pathconf(opt.outdir, 'PC_NAME_MAX')
else:
path_max = 260
name_max = 255
@ -123,41 +123,17 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
if command.startswith(('#', '//')):
continue
# before splitting, escape single quotes so as not to mess
# up the parser
command = command.replace("'", "\\'")
try:
elements = shlex.split(command)
except ValueError as e:
print(str(e))
continue
if elements[0] == 'q':
if command.startswith('q '):
done = True
break
if elements[0].startswith(
if command.startswith(
'!dream'
): # in case a stored prompt still contains the !dream command
elements.pop(0)
# rearrange the arguments to mimic how it works in the Dream bot.
switches = ['']
switches_started = False
for el in elements:
if el[0] == '-' and not switches_started:
switches_started = True
if switches_started:
switches.append(el)
else:
switches[0] += el
switches[0] += ' '
switches[0] = switches[0][: len(switches[0]) - 1]
command.replace('!dream','',1)
try:
opt = parser.parse_args(switches)
parser = opt.parse_cmd(command)
except SystemExit:
parser.print_help()
continue
@ -185,6 +161,7 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
opt.seed = None
continue
# TODO - move this into a module
if opt.with_variations is not None:
# shotgun parsing, woo
parts = []
@ -220,7 +197,7 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
# truncate path to maximum allowed length
# 27 is the length of '######.##########.##.png', plus two separators and a NUL
subdir = subdir[:(path_max - 27 - len(os.path.abspath(outdir)))]
subdir = subdir[:(path_max - 27 - len(os.path.abspath(opt.outdir)))]
current_outdir = os.path.join(outdir, subdir)
print('Writing files to directory: "' + current_outdir + '"')
@ -248,31 +225,36 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
filename = f'{prefix}.{seed}.postprocessed.png'
else:
filename = f'{prefix}.{seed}.png'
# the handling of variations is probably broken
# Also, given the ability to add stuff to the dream_prompt_str, it isn't
# necessary to make a copy of the opt option just to change its attributes
if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy
iter_opt = copy.copy(opt)
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(
gen, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
formatted_dream_prompt = iter_opt.dream_prompt_str(seed=seed)
elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(
gen, opt).normalize_prompt()
# use the original seed - the per-iteration value is the last variation-seed
metadata_prompt = f'{normalized_prompt} -S{opt.seed}'
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
else:
normalized_prompt = PromptFormatter(
gen, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}'
formatted_dream_prompt = opt.dream_prompt_str(seed=seed)
path = file_writer.save_image_and_prompt_to_png(
image, metadata_prompt, filename)
image = image,
dream_prompt = formatted_dream_prompt,
metadata = format_metadata(
opt,
seeds = [seed],
weights = gen.weights,
model_hash = gen.model_hash,
),
name = filename,
)
if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output
results.append([path, metadata_prompt])
results.append([path, formatted_dream_prompt])
last_results.append([path, seed])
catch_ctrl_c = infile is None # if running interactively, we catch keyboard interrupts
@ -286,15 +268,22 @@ def main_loop(gen, outdir, prompt_as_dir, parser, infile):
grid_img = make_grid(list(grid_images.values()))
grid_seeds = list(grid_images.keys())
first_seed = last_results[0][1]
filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images
normalized_prompt = PromptFormatter(
gen, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -n{len(grid_images)} # {grid_seeds}'
filename = f'{prefix}.{first_seed}.png'
formatted_dream_prompt = opt.dream_prompt_str(seed=first_seed,grid=True,iterations=len(grid_images))
formatted_dream_prompt += f' # {grid_seeds}'
metadata = format_metadata(
opt,
seeds = grid_seeds,
weights = gen.weights,
model_hash = gen.model_hash
)
path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename
image = grid_img,
dream_prompt = formatted_dream_prompt,
metadata = metadata,
name = filename
)
results = [[path, metadata_prompt]]
results = [[path, formatted_dream_prompt]]
except AssertionError as e:
print(e)
@ -325,7 +314,6 @@ def get_next_command(infile=None) -> str: # command string
print(f'#{command}')
return command
def dream_server_loop(gen, host, port, outdir):
print('\n* --web was specified, starting web server...')
# Change working directory to the stable-diffusion directory
@ -365,315 +353,5 @@ def write_log_message(results, log_path):
with open(log_path, 'a', encoding='utf-8') as file:
file.writelines(log_lines)
SAMPLER_CHOICES = [
'ddim',
'k_dpm_2_a',
'k_dpm_2',
'k_euler_a',
'k_euler',
'k_heun',
'k_lms',
'plms',
]
def create_argv_parser():
parser = argparse.ArgumentParser(
description="""Generate images using Stable Diffusion.
Use --web to launch the web interface.
Use --from_file to load prompts from a file path or standard input ("-").
Otherwise you will be dropped into an interactive command prompt (type -h for help.)
Other command-line arguments are defaults that can usually be overridden
prompt the command prompt.
"""
)
parser.add_argument(
'--laion400m',
'--latent_diffusion',
'-l',
dest='laion400m',
action='store_true',
help='Fallback to the latent diffusion (laion400m) weights and config',
)
parser.add_argument(
'--from_file',
dest='infile',
type=str,
help='If specified, load prompts from this file',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of images to generate',
)
parser.add_argument(
'-F',
'--full_precision',
dest='full_precision',
action='store_true',
help='Use more memory-intensive full precision math for calculations',
)
parser.add_argument(
'-g',
'--grid',
action='store_true',
help='Generate a grid instead of individual images',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
default='k_lms',
help=f'Set the initial sampler. Default: k_lms. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default='outputs/img-samples',
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'--embedding_path',
type=str,
help='Path to a pre-trained embedding manager checkpoint - can only be set on command line',
)
parser.add_argument(
'--prompt_as_dir',
'-p',
action='store_true',
help='Place images in subdirectories named after the prompt.',
)
# GFPGAN related args
parser.add_argument(
'--gfpgan_bg_upsampler',
type=str,
default='realesrgan',
help='Background upsampler. Default: realesrgan. Options: realesrgan, none.',
)
parser.add_argument(
'--gfpgan_bg_tile',
type=int,
default=400,
help='Tile size for background sampler, 0 for no tile during testing. Default: 400.',
)
parser.add_argument(
'--gfpgan_model_path',
type=str,
default='experiments/pretrained_models/GFPGANv1.3.pth',
help='Indicates the path to the GFPGAN model, relative to --gfpgan_dir.',
)
parser.add_argument(
'--gfpgan_dir',
type=str,
default='./src/gfpgan',
help='Indicates the directory containing the GFPGAN code.',
)
parser.add_argument(
'--web',
dest='web',
action='store_true',
help='Start in web server mode.',
)
parser.add_argument(
'--host',
type=str,
default='127.0.0.1',
help='Web server: Host or IP to listen on. Set to 0.0.0.0 to accept traffic from other devices on your network.'
)
parser.add_argument(
'--port',
type=int,
default='9090',
help='Web server: Port to listen on'
)
parser.add_argument(
'--weights',
default='model',
help='Indicates the Stable Diffusion model to use.',
)
parser.add_argument(
'--model',
default='stable-diffusion-1.4',
help='Indicates which diffusion model to load. (currently "stable-diffusion-1.4" (default) or "laion400m")',
)
parser.add_argument(
'--config',
default='configs/models.yaml',
help='Path to configuration file for alternate models.',
)
return parser
def create_cmd_parser():
parser = argparse.ArgumentParser(
description='Example: dream> a fantastic alien landscape -W1024 -H960 -s100 -n12'
)
parser.add_argument('prompt')
parser.add_argument('-s', '--steps', type=int, help='Number of steps')
parser.add_argument(
'-S',
'--seed',
type=int,
help='Image seed; a +ve integer, or use -1 for the previous seed, -2 for the one before that, etc',
)
parser.add_argument(
'-n',
'--iterations',
type=int,
default=1,
help='Number of samplings to perform (slower, but will provide seeds for individual images)',
)
parser.add_argument(
'-W', '--width', type=int, help='Image width, multiple of 64'
)
parser.add_argument(
'-H', '--height', type=int, help='Image height, multiple of 64'
)
parser.add_argument(
'-C',
'--cfg_scale',
default=7.5,
type=float,
help='Classifier free guidance (CFG) scale - higher numbers cause generator to "try" harder.',
)
parser.add_argument(
'-g', '--grid', action='store_true', help='generate a grid'
)
parser.add_argument(
'--outdir',
'-o',
type=str,
default=None,
help='Directory to save generated images and a log of prompts and seeds',
)
parser.add_argument(
'--seamless',
action='store_true',
help='Change the model to seamless tiling (circular) mode',
)
parser.add_argument(
'-i',
'--individual',
action='store_true',
help='Generate individual files (default)',
)
parser.add_argument(
'-I',
'--init_img',
type=str,
help='Path to input image for img2img mode (supersedes width and height)',
)
parser.add_argument(
'-M',
'--init_mask',
type=str,
help='Path to input mask for inpainting mode (supersedes width and height)',
)
parser.add_argument(
'-T',
'-fit',
'--fit',
action='store_true',
help='If specified, will resize the input image to fit within the dimensions of width x height (512x512 default)',
)
parser.add_argument(
'-f',
'--strength',
default=0.75,
type=float,
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
)
parser.add_argument(
'-G',
'--gfpgan_strength',
default=0,
type=float,
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
)
parser.add_argument(
'-U',
'--upscale',
nargs='+',
default=None,
type=float,
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75'
)
parser.add_argument(
'-save_orig',
'--save_original',
action='store_true',
help='Save original. Use it when upscaling to save both versions.',
)
parser.add_argument(
'-embiggen',
'--embiggen',
nargs='+',
default=None,
type=float,
help='Embiggen tiled img2img for higher resolution and detail without extra VRAM usage. Takes scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels. ESRGAN strength defaults to 0.75, and overlap defaults to 0.25 . ESRGAN is used to upscale the init prior to cutting it into tiles/pieces to run through img2img and then stitch back togeather.',
)
parser.add_argument(
'-embiggen_tiles',
'--embiggen_tiles',
nargs='+',
default=None,
type=int,
help='If while doing Embiggen we are altering only parts of the image, takes a list of tiles by number to process and replace onto the image e.g. `1 3 5`, useful for redoing problematic spots from a prior Embiggen run',
)
# variants is going to be superseded by a generalized "prompt-morph" function
# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants")
parser.add_argument(
'-x',
'--skip_normalize',
action='store_true',
help='Skip subprompt weight normalization',
)
parser.add_argument(
'-A',
'-m',
'--sampler',
dest='sampler_name',
default=None,
type=str,
choices=SAMPLER_CHOICES,
metavar='SAMPLER_NAME',
help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
)
parser.add_argument(
'-t',
'--log_tokenization',
action='store_true',
help='shows how the prompt is split into tokens'
)
parser.add_argument(
'-v',
'--variation_amount',
default=0.0,
type=float,
help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
'-V',
'--with_variations',
default=None,
type=str,
help='list of variations to apply, in the format `seed:weight,seed:weight,...'
)
return parser
if __name__ == '__main__':
main()

22
scripts/sd-metadata.py Normal file
View File

@ -0,0 +1,22 @@
#!/usr/bin/env python
import sys
import json
from ldm.dream.pngwriter import retrieve_metadata
if len(sys.argv) < 2:
print("Usage: file2prompt.py <file1.png> <file2.png> <file3.png>...")
print("This script opens up the indicated dream.py-generated PNG file(s) and prints out their metadata.")
exit(-1)
filenames = sys.argv[1:]
for f in filenames:
try:
metadata = retrieve_metadata(f)
print(f'{f}:\n',json.dumps(metadata, indent=4))
except FileNotFoundError:
sys.stderr.write(f'{f} not found\n')
continue
except PermissionError:
sys.stderr.write(f'{f} could not be opened due to inadequate permissions\n')
continue