From 91330878503ad1140fb2b2fbd523f01804a3c212 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Aug 2022 17:52:34 -0400 Subject: [PATCH 1/7] first draft at big refactoring; will be broken --- ldm/dream_util.py | 143 +++++++++++++++++++++++ ldm/simplet2i.py | 288 ++++++++++++++-------------------------------- scripts/dream.py | 130 ++------------------- 3 files changed, 238 insertions(+), 323 deletions(-) create mode 100644 ldm/dream_util.py diff --git a/ldm/dream_util.py b/ldm/dream_util.py new file mode 100644 index 0000000000..1526223cd8 --- /dev/null +++ b/ldm/dream_util.py @@ -0,0 +1,143 @@ +'''Utilities for dealing with PNG images and their path names''' +import os +import atexit +from PIL import Image,PngImagePlugin + +# ---------------readline utilities--------------------- +try: + import readline + readline_available = True +except: + readline_available = False + +class Completer(): + def __init__(self,options): + self.options = sorted(options) + return + + def complete(self,text,state): + buffer = readline.get_line_buffer() + + if text.startswith(('-I','--init_img')): + return self._path_completions(text,state,('.png')) + + if buffer.strip().endswith('cd') or text.startswith(('.','/')): + return self._path_completions(text,state,()) + + response = None + if state == 0: + # This is the first time for this text, so build a match list. + if text: + self.matches = [s + for s in self.options + if s and s.startswith(text)] + else: + self.matches = self.options[:] + + # Return the state'th item from the match list, + # if we have that many. + try: + response = self.matches[state] + except IndexError: + response = None + return response + + def _path_completions(self,text,state,extensions): + # get the path so far + if text.startswith('-I'): + path = text.replace('-I','',1).lstrip() + elif text.startswith('--init_img='): + path = text.replace('--init_img=','',1).lstrip() + else: + path = text + + matches = list() + + path = os.path.expanduser(path) + if len(path)==0: + matches.append(text+'./') + else: + dir = os.path.dirname(path) + dir_list = os.listdir(dir) + for n in dir_list: + if n.startswith('.') and len(n)>1: + continue + full_path = os.path.join(dir,n) + if full_path.startswith(path): + if os.path.isdir(full_path): + matches.append(os.path.join(os.path.dirname(text),n)+'/') + elif n.endswith(extensions): + matches.append(os.path.join(os.path.dirname(text),n)) + + try: + response = matches[state] + except IndexError: + response = None + return response + +if readline_available: + readline.set_completer(Completer(['cd','pwd', + '--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b', + '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g', + '--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete) + readline.set_completer_delims(" ") + readline.parse_and_bind('tab: complete') + + histfile = os.path.join(os.path.expanduser('~'),".dream_history") + try: + readline.read_history_file(histfile) + readline.set_history_length(1000) + except FileNotFoundError: + pass + atexit.register(readline.write_history_file,histfile) + +# -------------------image generation utils----- +class PngWriter: + + def __init__(self,opt): + self.opt = opt + self.filepath = None + self.files_written = [] + + def write_image(self,image,seed): + self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way + try: + image.save(self.filename) + except IOError as e: + print(e) + self.files_written.append([self.filepath,seed]) + + def unique_filename(self,opt,seed,previouspath): + revision = 1 + + if previouspath is None: + # sort reverse alphabetically until we find max+1 + dirlist = sorted(os.listdir(outdir),reverse=True) + # find the first filename that matches our pattern or return 000000.0.png + filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') + basecount = int(filename.split('.',1)[0]) + basecount += 1 + if opt.batch_size > 1: + filename = f'{basecount:06}.{seed}.01.png' + else: + filename = f'{basecount:06}.{seed}.png' + return os.path.join(outdir,filename) + + else: + basename = os.path.basename(previouspath) + x = re.match('^(\d+)\..*\.png',basename) + if not x: + return self.unique_filename(opt,seed,previouspath) + + basecount = int(x.groups()[0]) + series = 0 + finished = False + while not finished: + series += 1 + filename = f'{basecount:06}.{seed}.png' + if isbatch or os.path.exists(os.path.join(outdir,filename)): + filename = f'{basecount:06}.{seed}.{series:02}.png' + finished = not os.path.exists(os.path.join(outdir,filename)) + return os.path.join(outdir,filename) + + diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 4737d90ba7..8e8b077922 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -23,7 +23,6 @@ t2i = T2I(outdir = // outputs/txt2img-samples width = // image width, multiple of 64 (512) height = // image height, multiple of 64 (512) cfg_scale = // unconditional guidance scale (7.5) - fixed_code = // False ) # do the slow model initialization @@ -79,7 +78,6 @@ class T2I: """T2I class Attributes ---------- - outdir model config iterations @@ -87,12 +85,9 @@ class T2I: steps seed sampler_name - grid - individual width height cfg_scale - fixed_code latent_channels downsampling_factor precision @@ -102,11 +97,8 @@ class T2I: The vast majority of these arguments default to reasonable values. """ def __init__(self, - outdir="outputs/txt2img-samples", batch_size=1, iterations = 1, - width=512, - height=512, grid=False, individual=None, # redundant steps=50, @@ -118,7 +110,6 @@ The vast majority of these arguments default to reasonable values. latent_channels=4, downsampling_factor=8, ddim_eta=0.0, # deterministic - fixed_code=False, precision='autocast', full_precision=False, strength=0.75, # default in scripts/img2img.py @@ -126,7 +117,6 @@ The vast majority of these arguments default to reasonable values. latent_diffusion_weights=False, # just to keep track of this parameter when regenerating prompt device='cuda' ): - self.outdir = outdir self.batch_size = batch_size self.iterations = iterations self.width = width @@ -137,7 +127,6 @@ The vast majority of these arguments default to reasonable values. self.weights = weights self.config = config self.sampler_name = sampler_name - self.fixed_code = fixed_code self.latent_channels = latent_channels self.downsampling_factor = downsampling_factor self.ddim_eta = ddim_eta @@ -154,16 +143,25 @@ The vast majority of these arguments default to reasonable values. else: self.seed = seed - @torch.no_grad() - def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None, - steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None,init_img=None, - skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed - """ - Generate an image from the prompt, writing iteration images into the outdir - The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] - """ - outdir = outdir or self.outdir + def generate(self, + # these are common + prompt, + batch_size=None, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + # these are specific to txt2img + width=None, + height=None, + # these are specific to img2img + init_img=None, + strength=None, + variants=None): + '''ldm.generate() is the common entry point for txt2img() and img2img()''' steps = steps or self.steps seed = seed or self.seed width = width or self.width @@ -172,41 +170,57 @@ The vast majority of these arguments default to reasonable values. ddim_eta = ddim_eta or self.ddim_eta batch_size = batch_size or self.batch_size iterations = iterations or self.iterations - strength = strength or self.strength # not actually used here, but preserved for code refactoring - embedding_path = embedding_path or self.embedding_path + strength = strength or self.strength model = self.load_model() # will instantiate the model or return it from cache - - assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0" assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" + assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' - # grid and individual are mutually exclusive, with individual taking priority. - # not necessary, but needed for compatability with dream bot - if (grid is None): - grid = self.grid - if individual: - grid = False - data = [batch_size * [prompt]] + scope = autocast if self.precision=="autocast" else nullcontext + if grid: + callback = self.image2png + else: + callback = None - # make directories and establish names for the output files - os.makedirs(outdir, exist_ok=True) + tic = time.time() + if init_img: + assert os.path.exists(init_img),f'{init_img}: File not found' + results = self._img2img(prompt, + data=data,precision_scope=scope, + batch_size=batch_size,iterations=iterations, + steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + init_img=init_img,strength=strength,variants=variants, + callback=image_callback) + else: + results = self._txt2img(prompt, + data=data,precision_scope=scope, + batch_size=batch_size,iterations=iterations, + steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta, + skip_normalize=skip_normalize, + width=width,height=height, + callback=image_callback) + toc = time.time() + print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) + return results + + @torch.no_grad() + def _txt2img(self,prompt, + data,precision_scope, + batch_size,iterations, + steps,seed,cfg_scale,ddim_eta, + skip_normalize, + width,height, + callback=callback): # the callback is called each time a new Image is generated + """ + Generate an image from the prompt, writing iteration images into the outdir + The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...] + """ - start_code = None - if self.fixed_code: - start_code = torch.randn([batch_size, - self.latent_channels, - height // self.downsampling_factor, - width // self.downsampling_factor], - device=self.device) - - precision_scope = autocast if self.precision=="autocast" else nullcontext sampler = self.sampler images = list() - seeds = list() - filename = None image_count = 0 - tic = time.time() # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: @@ -239,38 +253,24 @@ The vast majority of these arguments default to reasonable values. shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples_ddim, _ = sampler.sample(S=steps, - conditioning=c, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - x_T=start_code) + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=ddim_eta) x_samples_ddim = model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - if not grid: - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,previousname=filename, - seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) - Image.fromarray(x_sample.astype(np.uint8)).save(filename) - images.append([filename,seed]) - else: - all_samples.append(x_samples_ddim) - seeds.append(seed) - - image_count += 1 + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + image = Image.fromarray(x_sample.astype(np.uint8)) + images.append([image,seed]) + if callback is not None: + callback(image,seed) + seed = self._new_seed() - if grid: - images = self._make_grid(samples=all_samples, - seeds=seeds, - batch_size=batch_size, - iterations=iterations, - outdir=outdir) except KeyboardInterrupt: print('*interrupted*') print('Partial results will be returned; if --grid was requested, nothing will be returned.') @@ -279,48 +279,20 @@ The vast majority of these arguments default to reasonable values. toc = time.time() print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) - return images - # There is lots of shared code between this and txt2img and should be refactored. @torch.no_grad() - def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None, - steps=None,seed=None,grid=None,individual=None,width=None,height=None, - cfg_scale=None,ddim_eta=None,strength=None,embedding_path=None, - skip_normalize=False,variants=None): # note the "variants" option is an unused hack caused by how options are passed + def _img2img(self,prompt, + data,precision_scope, + batch_size,iterations, + steps,seed,cfg_scale,ddim_eta, + skip_normalize, + init_img,strength,variants, + callback): """ Generate an image from the prompt and the initial image, writing iteration images into the outdir - The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...] + The output is a list of lists in the format: [[image,seed1], [image,seed2],...] """ - outdir = outdir or self.outdir - steps = steps or self.steps - seed = seed or self.seed - cfg_scale = cfg_scale or self.cfg_scale - ddim_eta = ddim_eta or self.ddim_eta - batch_size = batch_size or self.batch_size - iterations = iterations or self.iterations - strength = strength or self.strength - embedding_path = embedding_path or self.embedding_path - - assert strength<1.0 and strength>=0.0, "strength (-f) must be >=0.0 and <1.0" - assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" - - if init_img is None: - print("no init_img provided!") - return [] - - model = self.load_model() # will instantiate the model or return it from cache - - precision_scope = autocast if self.precision=="autocast" else nullcontext - - # grid and individual are mutually exclusive, with individual taking priority. - # not necessary, but needed for compatability with dream bot - if (grid is None): - grid = self.grid - if individual: - grid = False - - data = [batch_size * [prompt]] # PLMS sampler not supported yet, so ignore previous sampler if self.sampler_name!='ddim': @@ -329,33 +301,18 @@ The vast majority of these arguments default to reasonable values. else: sampler = self.sampler - # make directories and establish names for the output files - os.makedirs(outdir, exist_ok=True) - - assert os.path.isfile(init_img) init_image = self._load_img(init_img).to(self.device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) with precision_scope(self.device.type): init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) - - try: - assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' - except AssertionError: - print(f"strength must be between 0.0 and 1.0, but received value {strength}") - return [] t_enc = int(strength * steps) print(f"target t_enc is {t_enc} steps") images = list() - seeds = list() - filename = None - image_count = 0 # actual number of iterations performed - tic = time.time() - # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: with precision_scope(self.device.type), model.ema_scope(): all_samples = list() @@ -393,25 +350,13 @@ The vast majority of these arguments default to reasonable values. x_samples = model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - if not grid: - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,previousname=filename, - seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) - Image.fromarray(x_sample.astype(np.uint8)).save(filename) - images.append([filename,seed]) - else: - all_samples.append(x_samples) - seeds.append(seed) - image_count +=1 + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + image = Image.fromarray(x_sample.astype(np.uint8)) + images.append([image,seed]) + if callback is not None: + callback(image,seed) seed = self._new_seed() - if grid: - images = self._make_grid(samples=all_samples, - seeds=seeds, - batch_size=batch_size, - iterations=iterations, - outdir=outdir) except KeyboardInterrupt: print('*interrupted*') @@ -419,26 +364,6 @@ The vast majority of these arguments default to reasonable values. except RuntimeError as e: print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion") traceback.print_exc() - - toc = time.time() - print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) - - return images - - def _make_grid(self,samples,seeds,batch_size,iterations,outdir): - images = list() - n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations)) - # save as grid - grid = torch.stack(samples, 0) - grid = rearrange(grid, 'n b c h w -> (n b) c h w') - grid = make_grid(grid, nrow=n_rows) - - # to image - grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() - filename = self._unique_filename(outdir,seed=seeds[0],grid_count=batch_size*iterations) - Image.fromarray(grid.astype(np.uint8)).save(filename) - for s in seeds: - images.append([filename,s]) return images def _new_seed(self): @@ -513,43 +438,6 @@ The vast majority of these arguments default to reasonable values. image = torch.from_numpy(image) return 2.*image - 1. - def _unique_filename(self,outdir,previousname=None,seed=0,isbatch=False,grid_count=None): - revision = 1 - - if previousname is None: - # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(outdir),reverse=True) - # find the first filename that matches our pattern or return 000000.0.png - filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') - basecount = int(filename.split('.',1)[0]) - basecount += 1 - if grid_count is not None: - grid_label = f'grid#1-{grid_count}' - filename = f'{basecount:06}.{seed}.{grid_label}.png' - elif isbatch: - filename = f'{basecount:06}.{seed}.01.png' - else: - filename = f'{basecount:06}.{seed}.png' - - return os.path.join(outdir,filename) - - else: - previousname = os.path.basename(previousname) - x = re.match('^(\d+)\..*\.png',previousname) - if not x: - return self._unique_filename(outdir,previousname,seed) - - basecount = int(x.groups()[0]) - series = 0 - finished = False - while not finished: - series += 1 - filename = f'{basecount:06}.{seed}.png' - if isbatch or os.path.exists(os.path.join(outdir,filename)): - filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists(os.path.join(outdir,filename)) - return os.path.join(outdir,filename) - def _split_weighted_subprompts(text): """ grabs all text up to the first occurrence of ':' diff --git a/scripts/dream.py b/scripts/dream.py index dc5fad5bac..e0714dbcbd 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -8,13 +8,7 @@ import os import sys import copy from PIL import Image,PngImagePlugin - -# readline unavailable on windows systems -try: - import readline - readline_available = True -except: - readline_available = False +from ldm.dream_util import Completer,PngWriter debugging = False @@ -131,13 +125,13 @@ def main_loop(t2i,parser,log,infile): if elements[0]=='cd' and len(elements)>1: if os.path.exists(elements[1]): print(f"setting image output directory to {elements[1]}") - t2i.outdir=elements[1] + opt.outdir=elements[1] else: print(f"directory {elements[1]} does not exist") continue if elements[0]=='pwd': - print(f"current output directory is {t2i.outdir}") + print(f"current output directory is {opt.outdir}") continue if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command @@ -167,47 +161,19 @@ def main_loop(t2i,parser,log,infile): continue try: - if opt.init_img is None: - results = t2i.txt2img(**vars(opt)) - else: - assert os.path.exists(opt.init_img),f"No file found at {opt.init_img}. On Linux systems, pressing after -I will autocomplete a list of possible image files." - if None not in (opt.width,opt.height): - print('Warning: width and height options are ignored when modifying an init image') - results = t2i.img2img(**vars(opt)) + file_writer = PngWriter(opt) + opt.callback = file_writer(write_image) + run_generator(**vars(opt)) + results = file_writer.files_written except AssertionError as e: print(e) continue - - allVariantResults = [] - if opt.variants is not None: - print(f"Generating {opt.variants} variant(s)...") - newopt = copy.deepcopy(opt) - newopt.variants = None - for r in results: - newopt.init_img = r[0] - print(f"\t generating variant for {newopt.init_img}") - for j in range(0, opt.variants): - try: - variantResults = t2i.img2img(**vars(newopt)) - allVariantResults.append([newopt,variantResults]) - except AssertionError as e: - print(e) - continue - print(f"{opt.variants} Variants generated!") - print("Outputs:") write_log_message(t2i,opt,results,log) - - if allVariantResults: - print("Variant outputs:") - for vr in allVariantResults: - write_log_message(t2i,vr[0],vr[1],log) - print("goodbye!") - def write_log_message(t2i,opt,results,logfile): ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata ''' switches = _reconstruct_switches(t2i,opt) @@ -339,89 +305,7 @@ def create_cmd_parser(): parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization") return parser -if readline_available: - def setup_readline(): - readline.set_completer(Completer(['cd','pwd', - '--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b', - '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g', - '--individual','-i','--init_img','-I','--strength','-f','-v','--variants']).complete) - readline.set_completer_delims(" ") - readline.parse_and_bind('tab: complete') - load_history() - def load_history(): - histfile = os.path.join(os.path.expanduser('~'),".dream_history") - try: - readline.read_history_file(histfile) - readline.set_history_length(1000) - except FileNotFoundError: - pass - atexit.register(readline.write_history_file,histfile) - - class Completer(): - def __init__(self,options): - self.options = sorted(options) - return - - def complete(self,text,state): - buffer = readline.get_line_buffer() - - if text.startswith(('-I','--init_img')): - return self._path_completions(text,state,('.png')) - - if buffer.strip().endswith('cd') or text.startswith(('.','/')): - return self._path_completions(text,state,()) - - response = None - if state == 0: - # This is the first time for this text, so build a match list. - if text: - self.matches = [s - for s in self.options - if s and s.startswith(text)] - else: - self.matches = self.options[:] - - # Return the state'th item from the match list, - # if we have that many. - try: - response = self.matches[state] - except IndexError: - response = None - return response - - def _path_completions(self,text,state,extensions): - # get the path so far - if text.startswith('-I'): - path = text.replace('-I','',1).lstrip() - elif text.startswith('--init_img='): - path = text.replace('--init_img=','',1).lstrip() - else: - path = text - - matches = list() - - path = os.path.expanduser(path) - if len(path)==0: - matches.append(text+'./') - else: - dir = os.path.dirname(path) - dir_list = os.listdir(dir) - for n in dir_list: - if n.startswith('.') and len(n)>1: - continue - full_path = os.path.join(dir,n) - if full_path.startswith(path): - if os.path.isdir(full_path): - matches.append(os.path.join(os.path.dirname(text),n)+'/') - elif n.endswith(extensions): - matches.append(os.path.join(os.path.dirname(text),n)) - - try: - response = matches[state] - except IndexError: - response = None - return response if __name__ == "__main__": main() From b12955c9631ceec70154690eef072b99b9c71c32 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Aug 2022 17:57:44 -0400 Subject: [PATCH 2/7] remove unneeded imports from dream.py --- scripts/dream.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/scripts/dream.py b/scripts/dream.py index e0714dbcbd..6ff7802fa2 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -3,11 +3,9 @@ import argparse import shlex -import atexit import os import sys import copy -from PIL import Image,PngImagePlugin from ldm.dream_util import Completer,PngWriter debugging = False From b978536385613ab2c4bc4660633337e730020d75 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 24 Aug 2022 19:47:59 -0400 Subject: [PATCH 3/7] code is reorganized and mostly functional. Grid needs to be brought back online, as well as naming of img2img variants (currently the variants get written but not logged) --- ldm/dream_util.py | 64 ++++++++++++++++++++++++++-------- ldm/simplet2i.py | 76 +++++++++++++++++++--------------------- scripts/dream.py | 89 +++++++++-------------------------------------- 3 files changed, 101 insertions(+), 128 deletions(-) diff --git a/ldm/dream_util.py b/ldm/dream_util.py index 1526223cd8..ceab2940b1 100644 --- a/ldm/dream_util.py +++ b/ldm/dream_util.py @@ -1,6 +1,7 @@ '''Utilities for dealing with PNG images and their path names''' import os import atexit +import re from PIL import Image,PngImagePlugin # ---------------readline utilities--------------------- @@ -94,40 +95,43 @@ if readline_available: # -------------------image generation utils----- class PngWriter: - def __init__(self,opt): - self.opt = opt - self.filepath = None - self.files_written = [] + def __init__(self,outdir,opt,prompt): + self.outdir = outdir + self.opt = opt + self.prompt = prompt + self.filepath = None + self.files_written = [] def write_image(self,image,seed): - self.filepath = self.unique_filename(self,opt,seed,self.filepath) # will increment name in some sensible way + self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way try: - image.save(self.filename) + prompt = f'{self.prompt} -S{seed}' + self.save_image_and_prompt_to_png(image,prompt,self.filepath) except IOError as e: print(e) self.files_written.append([self.filepath,seed]) - def unique_filename(self,opt,seed,previouspath): + def unique_filename(self,seed,previouspath): revision = 1 if previouspath is None: # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(outdir),reverse=True) + dirlist = sorted(os.listdir(self.outdir),reverse=True) # find the first filename that matches our pattern or return 000000.0.png filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') basecount = int(filename.split('.',1)[0]) basecount += 1 - if opt.batch_size > 1: + if self.opt.batch_size > 1: filename = f'{basecount:06}.{seed}.01.png' else: filename = f'{basecount:06}.{seed}.png' - return os.path.join(outdir,filename) + return os.path.join(self.outdir,filename) else: basename = os.path.basename(previouspath) x = re.match('^(\d+)\..*\.png',basename) if not x: - return self.unique_filename(opt,seed,previouspath) + return self.unique_filename(seed,previouspath) basecount = int(x.groups()[0]) series = 0 @@ -135,9 +139,41 @@ class PngWriter: while not finished: series += 1 filename = f'{basecount:06}.{seed}.png' - if isbatch or os.path.exists(os.path.join(outdir,filename)): + if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists(os.path.join(outdir,filename)) - return os.path.join(outdir,filename) + finished = not os.path.exists(os.path.join(self.outdir,filename)) + return os.path.join(self.outdir,filename) + def save_image_and_prompt_to_png(self,image,prompt,path): + info = PngImagePlugin.PngInfo() + info.add_text("Dream",prompt) + image.save(path,"PNG",pnginfo=info) + +class PromptFormatter(): + def __init__(self,t2i,opt): + self.t2i = t2i + self.opt = opt + + def normalize_prompt(self): + '''Normalize the prompt and switches''' + t2i = self.t2i + opt = self.opt + + switches = list() + switches.append(f'"{opt.prompt}"') + switches.append(f'-s{opt.steps or t2i.steps}') + switches.append(f'-b{opt.batch_size or t2i.batch_size}') + switches.append(f'-W{opt.width or t2i.width}') + switches.append(f'-H{opt.height or t2i.height}') + switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + switches.append(f'-m{t2i.sampler_name}') + if opt.variants: + switches.append(f'-v{opt.variants}') + if opt.init_img: + switches.append(f'-I{opt.init_img}') + if opt.strength and opt.init_img is not None: + switches.append(f'-f{opt.strength or t2i.strength}') + if t2i.full_precision: + switches.append('-F') + return ' '.join(switches) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 8e8b077922..3b5aaeb696 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -99,13 +99,13 @@ The vast majority of these arguments default to reasonable values. def __init__(self, batch_size=1, iterations = 1, - grid=False, - individual=None, # redundant steps=50, seed=None, cfg_scale=7.5, weights="models/ldm/stable-diffusion-v1/model.ckpt", config = "configs/stable-diffusion/v1-inference.yaml", + width=512, + height=512, sampler_name="klms", latent_channels=4, downsampling_factor=8, @@ -121,7 +121,6 @@ The vast majority of these arguments default to reasonable values. self.iterations = iterations self.width = width self.height = height - self.grid = grid self.steps = steps self.cfg_scale = cfg_scale self.weights = weights @@ -143,25 +142,26 @@ The vast majority of these arguments default to reasonable values. else: self.seed = seed - def generate(self, - # these are common - prompt, - batch_size=None, - iterations=None, - steps=None, - seed=None, - cfg_scale=None, - ddim_eta=None, - skip_normalize=False, - image_callback=None, - # these are specific to txt2img - width=None, - height=None, - # these are specific to img2img - init_img=None, - strength=None, - variants=None): - '''ldm.generate() is the common entry point for txt2img() and img2img()''' + def prompt2image(self, + # these are common + prompt, + batch_size=None, + iterations=None, + steps=None, + seed=None, + cfg_scale=None, + ddim_eta=None, + skip_normalize=False, + image_callback=None, + # these are specific to txt2img + width=None, + height=None, + # these are specific to img2img + init_img=None, + strength=None, + variants=None, + **args): # eat up additional cruft + '''ldm.prompt2image() is the common entry point for txt2img() and img2img()''' steps = steps or self.steps seed = seed or self.seed width = width or self.width @@ -178,10 +178,6 @@ The vast majority of these arguments default to reasonable values. data = [batch_size * [prompt]] scope = autocast if self.precision=="autocast" else nullcontext - if grid: - callback = self.image2png - else: - callback = None tic = time.time() if init_img: @@ -212,7 +208,7 @@ The vast majority of these arguments default to reasonable values. steps,seed,cfg_scale,ddim_eta, skip_normalize, width,height, - callback=callback): # the callback is called each time a new Image is generated + callback): # the callback is called each time a new Image is generated """ Generate an image from the prompt, writing iteration images into the outdir The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...] @@ -224,14 +220,14 @@ The vast majority of these arguments default to reasonable values. # Gawd. Too many levels of indent here. Need to refactor into smaller routines! try: - with precision_scope(self.device.type), model.ema_scope(): + with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list() for n in trange(iterations, desc="Sampling"): seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) + uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -247,9 +243,9 @@ The vast majority of these arguments default to reasonable values. weight = weights[i] if not skip_normalize: weight = weight / totalWeight - c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) else: # just standard 1 prompt - c = model.get_learned_conditioning(prompts) + c = self.model.get_learned_conditioning(prompts) shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples_ddim, _ = sampler.sample(S=steps, @@ -261,7 +257,7 @@ The vast majority of these arguments default to reasonable values. unconditional_conditioning=uc, eta=ddim_eta) - x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = self.model.decode_first_stage(samples_ddim) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples_ddim: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') @@ -277,8 +273,6 @@ The vast majority of these arguments default to reasonable values. except RuntimeError as e: print(str(e)) - toc = time.time() - print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) return images @torch.no_grad() @@ -297,14 +291,14 @@ The vast majority of these arguments default to reasonable values. # PLMS sampler not supported yet, so ignore previous sampler if self.sampler_name!='ddim': print(f"sampler '{self.sampler_name}' is not yet supported. Using DDM sampler") - sampler = DDIMSampler(model, device=self.device) + sampler = DDIMSampler(self.model, device=self.device) else: sampler = self.sampler init_image = self._load_img(init_img).to(self.device) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) with precision_scope(self.device.type): - init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image)) # move to latent space sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) @@ -314,14 +308,14 @@ The vast majority of these arguments default to reasonable values. images = list() try: - with precision_scope(self.device.type), model.ema_scope(): + with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list() for n in trange(iterations, desc="Sampling"): seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) + uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple): prompts = list(prompts) @@ -337,9 +331,9 @@ The vast majority of these arguments default to reasonable values. weight = weights[i] if not skip_normalize: weight = weight / totalWeight - c = torch.add(c,model.get_learned_conditioning(subprompts[i]), alpha=weight) + c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) else: # just standard 1 prompt - c = model.get_learned_conditioning(prompts) + c = self.model.get_learned_conditioning(prompts) # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) @@ -347,7 +341,7 @@ The vast majority of these arguments default to reasonable values. samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc,) - x_samples = model.decode_first_stage(samples) + x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) for x_sample in x_samples: diff --git a/scripts/dream.py b/scripts/dream.py index 6ff7802fa2..ab01e8db01 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -6,7 +6,7 @@ import shlex import os import sys import copy -from ldm.dream_util import Completer,PngWriter +from ldm.dream_util import Completer,PngWriter,PromptFormatter debugging = False @@ -27,10 +27,6 @@ def main(): config = "configs/stable-diffusion/v1-inference.yaml" weights = "models/ldm/stable-diffusion-v1/model.ckpt" - # command line history will be stored in a file called "~/.dream_history" - if readline_available: - setup_readline() - print("* Initializing, be patient...\n") sys.path.append('.') from pytorch_lightning import logging @@ -46,8 +42,6 @@ def main(): # the user input loop t2i = T2I(width=width, height=height, - batch_size=opt.batch_size, - outdir=opt.outdir, sampler_name=opt.sampler_name, weights=weights, full_precision=opt.full_precision, @@ -79,13 +73,13 @@ def main(): log_path = os.path.join(opt.outdir,'dream_log.txt') with open(log_path,'a') as log: cmd_parser = create_cmd_parser() - main_loop(t2i,cmd_parser,log,infile) + main_loop(t2i,opt.outdir,cmd_parser,log,infile) log.close() if infile: infile.close() -def main_loop(t2i,parser,log,infile): +def main_loop(t2i,outdir,parser,log,infile): ''' prompt/read/execute loop ''' done = False @@ -123,13 +117,13 @@ def main_loop(t2i,parser,log,infile): if elements[0]=='cd' and len(elements)>1: if os.path.exists(elements[1]): print(f"setting image output directory to {elements[1]}") - opt.outdir=elements[1] + outdir=elements[1] else: print(f"directory {elements[1]} does not exist") continue if elements[0]=='pwd': - print(f"current output directory is {opt.outdir}") + print(f"current output directory is {outdir}") continue if elements[0].startswith('!dream'): # in case a stored prompt still contains the !dream command @@ -158,88 +152,41 @@ def main_loop(t2i,parser,log,infile): print("Try again with a prompt!") continue + normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() try: - file_writer = PngWriter(opt) - opt.callback = file_writer(write_image) - run_generator(**vars(opt)) + file_writer = PngWriter(outdir,opt,normalized_prompt) + callback = file_writer.write_image + + t2i.prompt2image(image_callback=callback, + **vars(opt)) results = file_writer.files_written + except AssertionError as e: print(e) continue print("Outputs:") - write_log_message(t2i,opt,results,log) + write_log_message(t2i,normalized_prompt,results,log) print("goodbye!") -def write_log_message(t2i,opt,results,logfile): +def write_log_message(t2i,prompt,results,logfile): ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata ''' - switches = _reconstruct_switches(t2i,opt) - prompt_str = ' '.join(switches) - - # when multiple images are produced in batch, then we keep track of where each starts last_seed = None img_num = 1 - batch_size = opt.batch_size or t2i.batch_size seenit = {} seeds = [a[1] for a in results] - if batch_size > 1: - seeds = f"(seeds for each batch row: {seeds})" - else: - seeds = f"(seeds for individual images: {seeds})" + seeds = f"(seeds for individual images: {seeds})" for r in results: seed = r[1] - log_message = (f'{r[0]}: {prompt_str} -S{seed}') + log_message = (f'{r[0]}: {prompt} -S{seed}') - if batch_size > 1: - if seed != last_seed: - img_num = 1 - log_message += f' # (batch image {img_num} of {batch_size})' - else: - img_num += 1 - log_message += f' # (batch image {img_num} of {batch_size})' - last_seed = seed print(log_message) logfile.write(log_message+"\n") logfile.flush() - if r[0] not in seenit: - seenit[r[0]] = True - try: - if opt.grid: - _write_prompt_to_png(r[0],f'{prompt_str} -g -S{seed} {seeds}') - else: - _write_prompt_to_png(r[0],f'{prompt_str} -S{seed}') - except FileNotFoundError: - print(f"Could not open file '{r[0]}' for reading") -def _reconstruct_switches(t2i,opt): - '''Normalize the prompt and switches''' - switches = list() - switches.append(f'"{opt.prompt}"') - switches.append(f'-s{opt.steps or t2i.steps}') - switches.append(f'-b{opt.batch_size or t2i.batch_size}') - switches.append(f'-W{opt.width or t2i.width}') - switches.append(f'-H{opt.height or t2i.height}') - switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') - switches.append(f'-m{t2i.sampler_name}') - if opt.variants: - switches.append(f'-v{opt.variants}') - if opt.init_img: - switches.append(f'-I{opt.init_img}') - if opt.strength and opt.init_img is not None: - switches.append(f'-f{opt.strength or t2i.strength}') - if t2i.full_precision: - switches.append('-F') - return switches - -def _write_prompt_to_png(path,prompt): - info = PngImagePlugin.PngInfo() - info.add_text("Dream",prompt) - im = Image.open(path) - im.save(path,"PNG",pnginfo=info) - def create_argv_parser(): parser = argparse.ArgumentParser(description="Parse script's command line args") parser.add_argument("--laion400m", @@ -260,10 +207,6 @@ def create_argv_parser(): dest='full_precision', action='store_true', help="use slower full precision math for calculations") - parser.add_argument('-b','--batch_size', - type=int, - default=1, - help="number of images to produce per iteration (faster, but doesn't generate individual seeds") parser.add_argument('--sampler','-m', dest="sampler_name", choices=['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'], From 0b4459b7074060314aaafe2c142370d77ac56a46 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Aug 2022 00:42:37 -0400 Subject: [PATCH 4/7] mostly back to full functionality; just missing grid generation code --- ldm/dream_util.py | 171 ++++++++++++++++++++++---------------------- ldm/simplet2i.py | 175 ++++++++++++++++++++++++++++++++-------------- scripts/dream.py | 37 +++++++++- 3 files changed, 245 insertions(+), 138 deletions(-) diff --git a/ldm/dream_util.py b/ldm/dream_util.py index ceab2940b1..a1d0d3204b 100644 --- a/ldm/dream_util.py +++ b/ldm/dream_util.py @@ -4,6 +4,92 @@ import atexit import re from PIL import Image,PngImagePlugin +# -------------------image generation utils----- +class PngWriter: + + def __init__(self,outdir,prompt=None,batch_size=1): + self.outdir = outdir + self.batch_size = batch_size + self.prompt = prompt + self.filepath = None + self.files_written = [] + os.makedirs(outdir, exist_ok=True) + + def write_image(self,image,seed): + self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way + try: + prompt = f'{self.prompt} -S{seed}' + self.save_image_and_prompt_to_png(image,prompt,self.filepath) + except IOError as e: + print(e) + self.files_written.append([self.filepath,seed]) + + def unique_filename(self,seed,previouspath): + revision = 1 + + if previouspath is None: + # sort reverse alphabetically until we find max+1 + dirlist = sorted(os.listdir(self.outdir),reverse=True) + # find the first filename that matches our pattern or return 000000.0.png + filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') + basecount = int(filename.split('.',1)[0]) + basecount += 1 + if self.batch_size > 1: + filename = f'{basecount:06}.{seed}.01.png' + else: + filename = f'{basecount:06}.{seed}.png' + return os.path.join(self.outdir,filename) + + else: + basename = os.path.basename(previouspath) + x = re.match('^(\d+)\..*\.png',basename) + if not x: + return self.unique_filename(seed,previouspath) + + basecount = int(x.groups()[0]) + series = 0 + finished = False + while not finished: + series += 1 + filename = f'{basecount:06}.{seed}.png' + if self.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): + filename = f'{basecount:06}.{seed}.{series:02}.png' + finished = not os.path.exists(os.path.join(self.outdir,filename)) + return os.path.join(self.outdir,filename) + + def save_image_and_prompt_to_png(self,image,prompt,path): + info = PngImagePlugin.PngInfo() + info.add_text("Dream",prompt) + image.save(path,"PNG",pnginfo=info) + +class PromptFormatter(): + def __init__(self,t2i,opt): + self.t2i = t2i + self.opt = opt + + def normalize_prompt(self): + '''Normalize the prompt and switches''' + t2i = self.t2i + opt = self.opt + + switches = list() + switches.append(f'"{opt.prompt}"') + switches.append(f'-s{opt.steps or t2i.steps}') + switches.append(f'-b{opt.batch_size or t2i.batch_size}') + switches.append(f'-W{opt.width or t2i.width}') + switches.append(f'-H{opt.height or t2i.height}') + switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') + switches.append(f'-m{t2i.sampler_name}') + if opt.variants: + switches.append(f'-v{opt.variants}') + if opt.init_img: + switches.append(f'-I{opt.init_img}') + if opt.strength and opt.init_img is not None: + switches.append(f'-f{opt.strength or t2i.strength}') + if t2i.full_precision: + switches.append('-F') + return ' '.join(switches) + # ---------------readline utilities--------------------- try: import readline @@ -92,88 +178,3 @@ if readline_available: pass atexit.register(readline.write_history_file,histfile) -# -------------------image generation utils----- -class PngWriter: - - def __init__(self,outdir,opt,prompt): - self.outdir = outdir - self.opt = opt - self.prompt = prompt - self.filepath = None - self.files_written = [] - - def write_image(self,image,seed): - self.filepath = self.unique_filename(seed,self.filepath) # will increment name in some sensible way - try: - prompt = f'{self.prompt} -S{seed}' - self.save_image_and_prompt_to_png(image,prompt,self.filepath) - except IOError as e: - print(e) - self.files_written.append([self.filepath,seed]) - - def unique_filename(self,seed,previouspath): - revision = 1 - - if previouspath is None: - # sort reverse alphabetically until we find max+1 - dirlist = sorted(os.listdir(self.outdir),reverse=True) - # find the first filename that matches our pattern or return 000000.0.png - filename = next((f for f in dirlist if re.match('^(\d+)\..*\.png',f)),'0000000.0.png') - basecount = int(filename.split('.',1)[0]) - basecount += 1 - if self.opt.batch_size > 1: - filename = f'{basecount:06}.{seed}.01.png' - else: - filename = f'{basecount:06}.{seed}.png' - return os.path.join(self.outdir,filename) - - else: - basename = os.path.basename(previouspath) - x = re.match('^(\d+)\..*\.png',basename) - if not x: - return self.unique_filename(seed,previouspath) - - basecount = int(x.groups()[0]) - series = 0 - finished = False - while not finished: - series += 1 - filename = f'{basecount:06}.{seed}.png' - if self.opt.batch_size>1 or os.path.exists(os.path.join(self.outdir,filename)): - filename = f'{basecount:06}.{seed}.{series:02}.png' - finished = not os.path.exists(os.path.join(self.outdir,filename)) - return os.path.join(self.outdir,filename) - - def save_image_and_prompt_to_png(self,image,prompt,path): - info = PngImagePlugin.PngInfo() - info.add_text("Dream",prompt) - image.save(path,"PNG",pnginfo=info) - -class PromptFormatter(): - def __init__(self,t2i,opt): - self.t2i = t2i - self.opt = opt - - def normalize_prompt(self): - '''Normalize the prompt and switches''' - t2i = self.t2i - opt = self.opt - - switches = list() - switches.append(f'"{opt.prompt}"') - switches.append(f'-s{opt.steps or t2i.steps}') - switches.append(f'-b{opt.batch_size or t2i.batch_size}') - switches.append(f'-W{opt.width or t2i.width}') - switches.append(f'-H{opt.height or t2i.height}') - switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') - switches.append(f'-m{t2i.sampler_name}') - if opt.variants: - switches.append(f'-v{opt.variants}') - if opt.init_img: - switches.append(f'-I{opt.init_img}') - if opt.strength and opt.init_img is not None: - switches.append(f'-f{opt.strength or t2i.strength}') - if t2i.full_precision: - switches.append('-F') - return ' '.join(switches) - diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 3b5aaeb696..91b73def43 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -4,52 +4,6 @@ # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors - -"""Simplified text to image API for stable diffusion/latent diffusion - -Example Usage: - -from ldm.simplet2i import T2I -# Create an object with default values -t2i = T2I(outdir = // outputs/txt2img-samples - model = // models/ldm/stable-diffusion-v1/model.ckpt - config = // default="configs/stable-diffusion/v1-inference.yaml - iterations = // how many times to run the sampling (1) - batch_size = // how many images to generate per sampling (1) - steps = // 50 - seed = // current system time - sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms - grid = // false - width = // image width, multiple of 64 (512) - height = // image height, multiple of 64 (512) - cfg_scale = // unconditional guidance scale (7.5) - ) - -# do the slow model initialization -t2i.load_model() - -# Do the fast inference & image generation. Any options passed here -# override the default values assigned during class initialization -# Will call load_model() if the model was not previously loaded. -# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] -results = t2i.txt2img(prompt = "an astronaut riding a horse" - outdir = "./outputs/txt2img-samples) - ) - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') - -# Same thing, but using an initial image. -results = t2i.img2img(prompt = "an astronaut riding a horse" - outdir = "./outputs/img2img-samples" - init_img = "./sketches/horse+rider.png") - -for row in results: - print(f'filename={row[0]}') - print(f'seed ={row[1]}') -""" - import torch import numpy as np import random @@ -64,6 +18,7 @@ from torchvision.utils import make_grid from pytorch_lightning import seed_everything from torch import autocast from contextlib import contextmanager, nullcontext +import transformers import time import math import re @@ -73,6 +28,69 @@ from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ksampler import KSampler +from ldm.dream_util import PngWriter + +"""Simplified text to image API for stable diffusion/latent diffusion + +Example Usage: + +from ldm.simplet2i import T2I + +# Create an object with default values +t2i = T2I(model = // models/ldm/stable-diffusion-v1/model.ckpt + config = // configs/stable-diffusion/v1-inference.yaml + iterations = // how many times to run the sampling (1) + batch_size = // how many images to generate per sampling (1) + steps = // 50 + seed = // current system time + sampler_name= ['ddim', 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler', 'k_heun', 'k_lms', 'plms'] // k_lms + grid = // false + width = // image width, multiple of 64 (512) + height = // image height, multiple of 64 (512) + cfg_scale = // unconditional guidance scale (7.5) + ) + +# do the slow model initialization +t2i.load_model() + +# Do the fast inference & image generation. Any options passed here +# override the default values assigned during class initialization +# Will call load_model() if the model was not previously loaded and so +# may be slow at first. +# The method returns a list of images. Each row of the list is a sub-list of [filename,seed] +results = t2i.prompt2png(prompt = "an astronaut riding a horse", + outdir = "./outputs/samples", + iterations = 3) + +for row in results: + print(f'filename={row[0]}') + print(f'seed ={row[1]}') + +# Same thing, but using an initial image. +results = t2i.prompt2png(prompt = "an astronaut riding a horse", + outdir = "./outputs/, + iterations = 3, + init_img = "./sketches/horse+rider.png") + +for row in results: + print(f'filename={row[0]}') + print(f'seed ={row[1]}') + +# Same thing, but we return a series of Image objects, which lets you manipulate them, +# combine them, and save them under arbitrary names + +results = t2i.prompt2image(prompt = "an astronaut riding a horse" + outdir = "./outputs/") +for row in results: + im = row[0] + seed = row[1] + im.save(f'./outputs/samples/an_astronaut_riding_a_horse-{seed}.png') + im.thumbnail(100,100).save('./outputs/samples/astronaut_thumb.jpg') + +Note that the old txt2img() and img2img() calls are deprecated but will +still work. +""" + class T2I: """T2I class @@ -141,7 +159,30 @@ The vast majority of these arguments default to reasonable values. self.seed = self._new_seed() else: self.seed = seed + transformers.logging.set_verbosity_error() + def prompt2png(self,prompt,outdir,**kwargs): + ''' + Takes a prompt and an output directory, writes out the requested number + of PNG files, and returns an array of [[filename,seed],[filename,seed]...] + Optional named arguments are the same as those passed to T2I and prompt2image() + ''' + results = self.prompt2image(prompt,**kwargs) + pngwriter = PngWriter(outdir,prompt,kwargs.get('batch_size',self.batch_size)) + for r in results: + metadata_str = f'prompt2png("{prompt}" {kwargs} seed={r[1]}' # gets written into the PNG + pngwriter.write_image(r[0],r[1]) + return pngwriter.files_written + + def txt2img(self,prompt,**kwargs): + outdir = kwargs.get('outdir','outputs/img-samples') + return self.prompt2png(prompt,outdir,**kwargs) + + def img2img(self,prompt,**kwargs): + outdir = kwargs.get('outdir','outputs/img-samples') + assert 'init_img' in kwargs,'call to img2img() must include the init_img argument' + return self.prompt2png(prompt,outdir,**kwargs) + def prompt2image(self, # these are common prompt, @@ -161,7 +202,34 @@ The vast majority of these arguments default to reasonable values. strength=None, variants=None, **args): # eat up additional cruft - '''ldm.prompt2image() is the common entry point for txt2img() and img2img()''' + ''' + ldm.prompt2image() is the common entry point for txt2img() and img2img() + It takes the following arguments: + prompt // prompt string (no default) + iterations // iterations (1); image count=iterations x batch_size + batch_size // images per iteration (1) + steps // refinement steps per iteration + seed // seed for random number generator + width // width of image, in multiples of 64 (512) + height // height of image, in multiples of 64 (512) + cfg_scale // how strongly the prompt influences the image (7.5) (must be >1) + init_img // path to an initial image - its dimensions override width and height + strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely + ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) + variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants + callback // a function or method that will be called each time an image is generated + + To use the callback, define a function of method that receives two arguments, an Image object + and the seed. You can then do whatever you like with the image, including converting it to + different formats and manipulating it. For example: + + def process_image(image,seed): + image.save(f{'images/seed.png'}) + + The callback used by the prompt2png() can be found in ldm/dream_util.py. It contains code + to create the requested output directory, select a unique informative name for each image, and + write the prompt into the PNG metadata. + ''' steps = steps or self.steps seed = seed or self.seed width = width or self.width @@ -175,6 +243,12 @@ The vast majority of these arguments default to reasonable values. model = self.load_model() # will instantiate the model or return it from cache assert cfg_scale>1.0, "CFG_Scale (-C) must be >1.0" assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]' + w = int(width/64) * 64 + h = int(height/64) * 64 + if h != height or w != width: + print(f'Height and width must be multiples of 64. Resizing to {h}x{w}') + height = h + width = w data = [batch_size * [prompt]] scope = autocast if self.precision=="autocast" else nullcontext @@ -303,8 +377,7 @@ The vast majority of these arguments default to reasonable values. sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False) t_enc = int(strength * steps) - print(f"target t_enc is {t_enc} steps") - + # print(f"target t_enc is {t_enc} steps") images = list() try: @@ -408,8 +481,8 @@ The vast majority of these arguments default to reasonable values. def _load_model_from_config(self, config, ckpt): print(f"Loading model from {ckpt}") pl_sd = torch.load(ckpt, map_location="cpu") - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") +# if "global_step" in pl_sd: +# print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) diff --git a/scripts/dream.py b/scripts/dream.py index ab01e8db01..10acccbfc3 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -153,25 +153,58 @@ def main_loop(t2i,outdir,parser,log,infile): continue normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() + variants = None + try: - file_writer = PngWriter(outdir,opt,normalized_prompt) + file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size) callback = file_writer.write_image t2i.prompt2image(image_callback=callback, **vars(opt)) results = file_writer.files_written + if None not in (opt.variants,opt.init_img): + variants = generate_variants(t2i,outdir,opt,results) + except AssertionError as e: print(e) continue print("Outputs:") write_log_message(t2i,normalized_prompt,results,log) + if variants is not None: + print('Variants:') + for vr in variants: + write_log_message(t2i,vr[0],vr[1],log) print("goodbye!") +def generate_variants(t2i,outdir,opt,previous_gens): + variants = [] + print(f"Generating {opt.variants} variant(s)...") + newopt = copy.deepcopy(opt) + newopt.iterations = 1 + newopt.variants = None + for r in previous_gens: + newopt.init_img = r[0] + prompt = PromptFormatter(t2i,newopt).normalize_prompt() + print(f"] generating variant for {newopt.init_img}") + for j in range(0,opt.variants): + try: + file_writer = PngWriter(outdir,prompt,newopt.batch_size) + callback = file_writer.write_image + t2i.prompt2image(image_callback=callback,**vars(newopt)) + results = file_writer.files_written + variants.append([prompt,results]) + except AssertionError as e: + print(e) + continue + print(f'{opt.variants} variants generated') + return variants + + def write_log_message(t2i,prompt,results,logfile): - ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata ''' + ''' logs the name of the output image, its prompt and seed to the terminal, log file, and a Dream text chunk in the PNG metadata''' last_seed = None img_num = 1 seenit = {} From 49247b4aa4f65f98a903b8ad47932d35b940e1fb Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Aug 2022 09:41:12 -0400 Subject: [PATCH 5/7] fix performance regression; closes issue #42 --- ldm/simplet2i.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 91b73def43..a3b4ecfcc7 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -486,6 +486,7 @@ The vast majority of these arguments default to reasonable values. sd = pl_sd["state_dict"] model = instantiate_from_config(config.model) m, u = model.load_state_dict(sd, strict=False) + model.cuda() # fixes performance issue model.eval() if self.full_precision: print('Using slower but more accurate full-precision math (--full_precision)') From 26dc05e0e07e1fb93c8988832bf99cdb08f9d195 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Aug 2022 09:47:27 -0400 Subject: [PATCH 6/7] document --from_file flag, closes issue #82 --- README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/README.md b/README.md index 22c6447248..28a610db8c 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,26 @@ You may also pass a -v option to generate count variants on the original passing the first generated image back into img2img the requested number of times. It generates interesting variants. +## Reading Prompts from a File + +You can automate dream.py by providing a text file with the prompts +you want to run, one line per prompt. The text file must be composed +with a text editor (e.g. Notepad) and not a word processor. Each line +should look like what you would type at the dream> prompt: + +~~~~ +a beautiful sunny day in the park, children playing -n4 -C10 +stormy weather on a mountain top, goats grazing -s100 +innovative packaging for a squid's dinner -S137038382 +~~~~ + +Then pass this file's name to dream.py when you invoke it: + +~~~~ +(ldm) ~/stable-diffusion$ python3 scripts/dream.py --from_file="path/to/prompts.txt" +~~~~ + + ## Weighted Prompts You may weight different sections of the prompt to tell the sampler to attach different levels of From b3e3b0e8613e7f11e0ae3e5f722233290b0732bb Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 25 Aug 2022 17:26:48 -0400 Subject: [PATCH 7/7] feature complete; looks like ready for merge --- TODO.txt | 1 + ldm/dream_util.py | 21 +++++++++++-- ldm/simplet2i.py | 1 - scripts/dream.py | 78 +++++++++++++++++++++++++---------------------- src/k-diffusion | 2 +- 5 files changed, 61 insertions(+), 42 deletions(-) diff --git a/TODO.txt b/TODO.txt index df9aea75ba..32475b43ba 100644 --- a/TODO.txt +++ b/TODO.txt @@ -2,6 +2,7 @@ Feature requests: 1. "gobig" mode - split image into strips, scale up, add detail using img2img and reassemble with feathering. Issue #66. + See https://github.com/jquesnelle/txt2imghd 2. Port basujindal low VRAM optimizations. Issue #62 diff --git a/ldm/dream_util.py b/ldm/dream_util.py index a1d0d3204b..b69a0c1367 100644 --- a/ldm/dream_util.py +++ b/ldm/dream_util.py @@ -2,6 +2,7 @@ import os import atexit import re +from math import sqrt,floor,ceil from PIL import Image,PngImagePlugin # -------------------image generation utils----- @@ -24,7 +25,7 @@ class PngWriter: print(e) self.files_written.append([self.filepath,seed]) - def unique_filename(self,seed,previouspath): + def unique_filename(self,seed,previouspath=None): revision = 1 if previouspath is None: @@ -61,6 +62,22 @@ class PngWriter: info = PngImagePlugin.PngInfo() info.add_text("Dream",prompt) image.save(path,"PNG",pnginfo=info) + + def make_grid(self,image_list,rows=None,cols=None): + image_cnt = len(image_list) + if None in (rows,cols): + rows = floor(sqrt(image_cnt)) # try to make it square + cols = ceil(image_cnt/rows) + width = image_list[0].width + height = image_list[0].height + + grid_img = Image.new('RGB',(width*cols,height*rows)) + for r in range(0,rows): + for c in range (0,cols): + i = r*rows + c + grid_img.paste(image_list[i],(c*width,r*height)) + + return grid_img class PromptFormatter(): def __init__(self,t2i,opt): @@ -80,8 +97,6 @@ class PromptFormatter(): switches.append(f'-H{opt.height or t2i.height}') switches.append(f'-C{opt.cfg_scale or t2i.cfg_scale}') switches.append(f'-m{t2i.sampler_name}') - if opt.variants: - switches.append(f'-v{opt.variants}') if opt.init_img: switches.append(f'-I{opt.init_img}') if opt.strength and opt.init_img is not None: diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index a3b4ecfcc7..fe0d3819a1 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -20,7 +20,6 @@ from torch import autocast from contextlib import contextmanager, nullcontext import transformers import time -import math import re import traceback diff --git a/scripts/dream.py b/scripts/dream.py index 10acccbfc3..24dac5b927 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -153,54 +153,60 @@ def main_loop(t2i,outdir,parser,log,infile): continue normalized_prompt = PromptFormatter(t2i,opt).normalize_prompt() - variants = None + individual_images = not opt.grid try: file_writer = PngWriter(outdir,normalized_prompt,opt.batch_size) - callback = file_writer.write_image + callback = file_writer.write_image if individual_images else None - t2i.prompt2image(image_callback=callback, - **vars(opt)) - results = file_writer.files_written + image_list = t2i.prompt2image(image_callback=callback,**vars(opt)) + results = file_writer.files_written if individual_images else image_list - if None not in (opt.variants,opt.init_img): - variants = generate_variants(t2i,outdir,opt,results) + if opt.grid and len(results) > 0: + grid_img = file_writer.make_grid([r[0] for r in results]) + filename = file_writer.unique_filename(results[0][1]) + seeds = [a[1] for a in results] + results = [[filename,seeds]] + metadata_prompt = f'{normalized_prompt} -S{results[0][1]}' + file_writer.save_image_and_prompt_to_png(grid_img,metadata_prompt,filename) except AssertionError as e: print(e) continue + except OSError as e: + print(e) + continue + print("Outputs:") write_log_message(t2i,normalized_prompt,results,log) - if variants is not None: - print('Variants:') - for vr in variants: - write_log_message(t2i,vr[0],vr[1],log) print("goodbye!") -def generate_variants(t2i,outdir,opt,previous_gens): - variants = [] - print(f"Generating {opt.variants} variant(s)...") - newopt = copy.deepcopy(opt) - newopt.iterations = 1 - newopt.variants = None - for r in previous_gens: - newopt.init_img = r[0] - prompt = PromptFormatter(t2i,newopt).normalize_prompt() - print(f"] generating variant for {newopt.init_img}") - for j in range(0,opt.variants): - try: - file_writer = PngWriter(outdir,prompt,newopt.batch_size) - callback = file_writer.write_image - t2i.prompt2image(image_callback=callback,**vars(newopt)) - results = file_writer.files_written - variants.append([prompt,results]) - except AssertionError as e: - print(e) - continue - print(f'{opt.variants} variants generated') - return variants +# variant generation is going to be superseded by a generalized +# "prompt-morph" functionality +# def generate_variants(t2i,outdir,opt,previous_gens): +# variants = [] +# print(f"Generating {opt.variants} variant(s)...") +# newopt = copy.deepcopy(opt) +# newopt.iterations = 1 +# newopt.variants = None +# for r in previous_gens: +# newopt.init_img = r[0] +# prompt = PromptFormatter(t2i,newopt).normalize_prompt() +# print(f"] generating variant for {newopt.init_img}") +# for j in range(0,opt.variants): +# try: +# file_writer = PngWriter(outdir,prompt,newopt.batch_size) +# callback = file_writer.write_image +# t2i.prompt2image(image_callback=callback,**vars(newopt)) +# results = file_writer.files_written +# variants.append([prompt,results]) +# except AssertionError as e: +# print(e) +# continue +# print(f'{opt.variants} variants generated') +# return variants def write_log_message(t2i,prompt,results,logfile): @@ -209,9 +215,6 @@ def write_log_message(t2i,prompt,results,logfile): img_num = 1 seenit = {} - seeds = [a[1] for a in results] - seeds = f"(seeds for individual images: {seeds})" - for r in results: seed = r[1] log_message = (f'{r[0]}: {prompt} -S{seed}') @@ -275,7 +278,8 @@ def create_cmd_parser(): parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)") parser.add_argument('-I','--init_img',type=str,help="path to input image for img2img mode (supersedes width and height)") parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely") - parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants") +# variants is going to be superseded by a generalized "prompt-morph" function +# parser.add_argument('-v','--variants',type=int,help="in img2img mode, the first generated image will get passed back to img2img to generate the requested number of variants") parser.add_argument('-x','--skip_normalize',action='store_true',help="skip subprompt weight normalization") return parser diff --git a/src/k-diffusion b/src/k-diffusion index db57990687..ef1bf07627 160000 --- a/src/k-diffusion +++ b/src/k-diffusion @@ -1 +1 @@ -Subproject commit db5799068749bf3a6d5845120ed32df16b7d883b +Subproject commit ef1bf07627c9a10ba9137e68a0206b844544a7d9