diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index febc0e461f..bcd3e5a8e6 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -8,7 +8,7 @@ t2i = T2I(outdir = // outputs/txt2img-samples
           model = // models/ldm/stable-diffusion-v1/model.ckpt
           config = // default="configs/stable-diffusion/v1-inference.yaml
           iterations = // how many times to run the sampling (1)
-          batch = // how many images to generate per sampling (1)
+          batch_size = // how many images to generate per sampling (1)
           steps = // 50
           seed = // current system time
           sampler = ['ddim','plms'] // ddim
@@ -73,7 +73,7 @@ class T2I:
     model
     config
     iterations
-    batch
+    batch_size
     steps
     seed
     sampler
@@ -90,7 +90,7 @@ class T2I:
     """
     def __init__(self,
                  outdir="outputs/txt2img-samples",
-                 batch=1,
+                 batch_size=1,
                  iterations = 1,
                  width=512,
                  height=512,
@@ -110,7 +110,7 @@ class T2I:
                  strength=0.75 # default in scripts/img2img.py
     ):
         self.outdir = outdir
-        self.batch = batch
+        self.batch_size = batch_size
         self.iterations = iterations
         self.width = width
         self.height = height
@@ -133,7 +133,7 @@ class T2I:
         else:
             self.seed = seed
 
-    def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
+    def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
                 cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
         """
@@ -147,7 +147,7 @@ class T2I:
         height = height or self.height
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
-        batch = batch or self.batch
+        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength # not actually used here, but preserved for code refactoring
 
@@ -160,7 +160,7 @@ class T2I:
         if individual:
             grid = False
 
-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]
 
         # make directories and establish names for the output files
         os.makedirs(outdir, exist_ok=True)
@@ -168,7 +168,7 @@ class T2I:
 
         start_code = None
         if self.fixed_code:
-            start_code = torch.randn([batch,
+            start_code = torch.randn([batch_size,
                                       self.latent_channels,
                                       height // self.downsampling_factor,
                                       width // self.downsampling_factor],
@@ -190,14 +190,14 @@ class T2I:
                         for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                             uc = None
                             if cfg_scale != 1.0:
-                                uc = model.get_learned_conditioning(batch * [""])
+                                uc = model.get_learned_conditioning(batch_size * [""])
                             if isinstance(prompts, tuple):
                                 prompts = list(prompts)
                             c = model.get_learned_conditioning(prompts)
                             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                             samples_ddim, _ = sampler.sample(S=steps,
                                                              conditioning=c,
-                                                             batch_size=batch,
+                                                             batch_size=batch_size,
                                                              shape=shape,
                                                              verbose=False,
                                                              unconditional_guidance_scale=cfg_scale,
@@ -224,17 +224,17 @@ class T2I:
         if grid:
             images = self._make_grid(samples=all_samples,
                                      seeds=seeds,
-                                     batch_size=batch,
+                                     batch_size=batch_size,
                                      iterations=iterations,
                                      outdir=outdir)
 
         toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
 
         return images
 
     # There is lots of shared code between this and txt2img and should be refactored.
-    def img2img(self,prompt,outdir=None,init_img=None,batch=None,iterations=None,
+    def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
                 cfg_scale=None,ddim_eta=None,strength=None):
         """
@@ -246,7 +246,7 @@ class T2I:
         seed = seed or self.seed
         cfg_scale = cfg_scale or self.cfg_scale
         ddim_eta = ddim_eta or self.ddim_eta
-        batch = batch or self.batch
+        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength
 
@@ -263,7 +263,7 @@ class T2I:
         if individual:
             grid = False
 
-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]
 
         # PLMS sampler not supported yet, so ignore previous sampler
         if self.sampler_name!='ddim':
@@ -278,7 +278,7 @@ class T2I:
 
         assert os.path.isfile(init_img)
         init_image = self._load_img(init_img).to(self.device)
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch)
+        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
         init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
 
         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
@@ -307,13 +307,13 @@ class T2I:
                         for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                             uc = None
                             if cfg_scale != 1.0:
-                                uc = model.get_learned_conditioning(batch * [""])
+                                uc = model.get_learned_conditioning(batch_size * [""])
                             if isinstance(prompts, tuple):
                                 prompts = list(prompts)
                             c = model.get_learned_conditioning(prompts)
 
                             # encode (scaled latent)
-                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch).to(self.device))
+                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
                             # decode it
                             samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                                      unconditional_conditioning=uc,)
@@ -337,12 +337,12 @@ class T2I:
         if grid:
             images = self._make_grid(samples=all_samples,
                                      seeds=seeds,
-                                     batch_size=batch,
+                                     batch_size=batch_size,
                                      iterations=iterations,
                                      outdir=outdir)
 
         toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
 
         return images
 
diff --git a/scripts/dream.py b/scripts/dream.py
index 3d385e454f..869e407117 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -6,6 +6,8 @@ import shlex
 import atexit
 import os
 
+debugging = False
+
 def main():
     ''' Initialize command-line parsers and the diffusion model '''
     arg_parser = create_argv_parser()
@@ -24,7 +26,7 @@ def main():
         weights = "models/ldm/stable-diffusion-v1/model.ckpt"
 
     # command line history will be stored in a file called "~/.dream_history"
-    load_history()
+    setup_readline()
 
     print("* Initializing, be patient...\n")
     from pytorch_lightning import logging
@@ -36,7 +38,7 @@ def main():
     # the user input loop
     t2i = T2I(width=width,
               height=height,
-              batch=opt.batch,
+              batch_size=opt.batch_size,
               outdir=opt.outdir,
               sampler=opt.sampler,
               weights=weights,
@@ -50,7 +52,8 @@ def main():
     logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
 
     # preload the model
-    t2i.load_model()
+    if not debugging:
+        t2i.load_model()
     print("\n* Initialization done! Awaiting your command (-h for help)...")
 
     log_path = os.path.join(opt.outdir,"dream_log.txt")
@@ -139,7 +142,7 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
-    parser.add_argument('-b','--batch',
+    parser.add_argument('-b','--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
@@ -161,7 +164,7 @@ def create_cmd_parser():
     parser.add_argument('-s','--steps',type=int,help="number of steps")
     parser.add_argument('-S','--seed',type=int,help="image seed")
     parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
-    parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
+    parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
     parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
     parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
     parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
@@ -171,6 +174,14 @@ def create_cmd_parser():
     parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
     return parser
 
+def setup_readline():
+    readline.set_completer(Completer(['--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
+                                      '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
+                                      '--individual','-i','--init_img','-I','--strength','-f']).complete)
+    readline.set_completer_delims(" ")
+    readline.parse_and_bind('tab: complete')
+    load_history()
+
 def load_history():
     histfile = os.path.join(os.path.expanduser('~'),".dream_history")
     try:
@@ -180,5 +191,86 @@ def load_history():
         pass
     atexit.register(readline.write_history_file,histfile)
 
+class Completer():
+    '''
+    Tab completer for the dream> prompt: completes the command switches given
+    to the constructor and, for -I<path> and --init_img=<path>, directories and .png files.
+    '''
+    def __init__(self,options):
+        self.options = sorted(options)
+        # rebuilt by complete() each time readline asks for state 0
+        self.matches = []
+        return
+
+    def complete(self,text,state):
+        # a bare '--init_img' has no path attached yet, so it is completed like any other switch
+        if text.startswith('-I') or text.startswith('--init_img='):
+            return self._image_completions(text,state)
+
+        response = None
+        if state == 0:
+            # This is the first time for this text, so build a match list.
+            if text:
+                self.matches = [s
+                                for s in self.options
+                                if s and s.startswith(text)]
+            else:
+                self.matches = self.options[:]
+
+        # Return the state'th item from the match list,
+        # if we have that many.
+        try:
+            response = self.matches[state]
+        except IndexError:
+            response = None
+        return response
+
+    def _image_completions(self,text,state):
+        # get the path so far
+        if text.startswith('-I'):
+            switch = '-I'
+        elif text.startswith('--init_img='):
+            switch = '--init_img='
+        else:
+            # no path is attached to this word, so there is nothing to complete
+            return None
+        typed = text[len(switch):].lstrip()
+
+        matches = list()
+
+        path = os.path.expanduser(typed)
+        if len(path)==0:
+            matches.append(text+'./')
+        elif path != typed and '/' not in typed:
+            # a bare '~' or '~user' was expanded: add the slash so the next TAB lists that directory
+            matches.append(text+'/')
+        else:
+            # os.listdir('') raises, so a bare filename is looked up in the current directory
+            dirname = os.path.dirname(path)
+            partial = os.path.basename(path)
+            # everything typed up to and including the last '/'
+            prefix = text[:len(text)-len(os.path.basename(typed))]
+            try:
+                dir_list = sorted(os.listdir(dirname or '.'))
+            except OSError:
+                # missing or unreadable directory: offer no completions
+                dir_list = []
+            for n in dir_list:
+                if n.startswith('.') and len(n)>1:
+                    continue
+                if not n.startswith(partial):
+                    continue
+                if os.path.isdir(os.path.join(dirname,n)):
+                    matches.append(prefix+n+'/')
+                elif n.endswith('.png'):
+                    matches.append(prefix+n)
+
+        try:
+            response = matches[state]
+        except IndexError:
+            response = None
+        return response
+
+
 if __name__ == "__main__":
     main()