diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index eca0f3b68e..bcd3e5a8e6 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -8,7 +8,7 @@ t2i = T2I(outdir      = <path>         // outputs/txt2img-samples
           model       = <path>         // models/ldm/stable-diffusion-v1/model.ckpt
           config      = <path>         // default="configs/stable-diffusion/v1-inference.yaml
           iterations  = <integer>      // how many times to run the sampling (1)
-          batch       = <integer>      // how many images to generate per sampling (1)
+          batch_size  = <integer>      // how many images to generate per sampling (1)
           steps       = <integer>      // 50
           seed        = <integer>      // current system time
           sampler     = ['ddim','plms']  // ddim
@@ -26,23 +26,22 @@ t2i.load_model()
 # override the default values assigned during class initialization
 # Will call load_model() if the model was not previously loaded.
 # The method returns a list of images. Each row of the list is a sub-list of [filename,seed]
-results = t2i.txt2img(prompt     = <string>   // required
-                      outdir     = <path>     // the remaining option arguments override constructur value when present
-                      iterations = <integer>
-                      batch      = <integer>
-                      steps      = <integer>
-                      seed       = <integer>
-                      sampler    = ['ddim','plms']
-                      grid       = <boolean>
-                      width      = <integer>
-                      height     = <integer>
-                      cfg_scale  = <float>
-                      ) -> boolean
+results = t2i.txt2img(prompt = "an astronaut riding a horse",
+                      outdir = "./outputs/txt2img-samples"
+                      )
 
 for row in results:
     print(f'filename={row[0]}')
     print(f'seed    ={row[1]}')
+
+# Same thing, but using an initial image.
+results = t2i.img2img(prompt   = "an astronaut riding a horse",
+                      outdir   = "./outputs/img2img-samples",
+                      init_img = "./sketches/horse+rider.png")
+for row in results:
+    print(f'filename={row[0]}')
+    print(f'seed    ={row[1]}')
 """
 
 import torch
@@ -54,7 +53,7 @@ from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
 from itertools import islice
-from einops import rearrange
+from einops import rearrange, repeat
 from torchvision.utils import make_grid
 from pytorch_lightning import seed_everything
 from torch import autocast
@@ -74,7 +73,7 @@ class T2I:
         model
         config
         iterations
-        batch
+        batch_size
         steps
         seed
         sampler
@@ -87,10 +86,11 @@ class T2I:
         latent_channels
         downsampling_factor
         precision
+        strength
     """
     def __init__(self,
                  outdir="outputs/txt2img-samples",
-                 batch=1,
+                 batch_size=1,
                  iterations = 1,
                  width=512,
                  height=512,
@@ -106,10 +106,11 @@ class T2I:
                  downsampling_factor=8,
                  ddim_eta=0.0,  # deterministic
                  fixed_code=False,
-                 precision='autocast'
+                 precision='autocast',
+                 strength=0.75  # default in scripts/img2img.py
                  ):
         self.outdir = outdir
-        self.batch = batch
+        self.batch_size = batch_size
         self.iterations = iterations
         self.width = width
         self.height = height
@@ -124,15 +125,17 @@ class T2I:
         self.downsampling_factor = downsampling_factor
         self.ddim_eta = ddim_eta
         self.precision = precision
+        self.strength = strength
         self.model = None      # empty for now
         self.sampler = None
         if seed is None:
             self.seed = self._new_seed()
         else:
             self.seed = seed
-    def txt2img(self,prompt,outdir=None,batch=None,iterations=None,
+
+    def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
                 steps=None,seed=None,grid=None,individual=None,width=None,height=None,
-                cfg_scale=None,ddim_eta=None):
+                cfg_scale=None,ddim_eta=None,strength=None,init_img=None):
         """
         Generate an image from the prompt, writing iteration images into the outdir
         The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
@@ -144,8 +147,9 @@ class T2I:
         height     = height     or self.height
         cfg_scale  = cfg_scale  or self.cfg_scale
         ddim_eta   = ddim_eta   or self.ddim_eta
-        batch      = batch      or self.batch
+        batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
+        strength   = strength   or self.strength   # not actually used here, but kept for the planned txt2img/img2img refactor
 
         model = self.load_model()  # will instantiate the model or return it from cache
 
@@ -156,7 +160,7 @@ class T2I:
         if individual:
             grid = False
 
-        data = [batch * [prompt]]
+        data = [batch_size * [prompt]]
 
         # make directories and establish names for the output files
         os.makedirs(outdir, exist_ok=True)
@@ -164,7 +168,7 @@ class T2I:
 
         start_code = None
         if self.fixed_code:
-            start_code = torch.randn([batch,
+            start_code = torch.randn([batch_size,
                                       self.latent_channels,
                                       height // self.downsampling_factor,
                                       width // self.downsampling_factor],
@@ -186,14 +190,14 @@ class T2I:
                     for prompts in tqdm(data, desc="data", dynamic_ncols=True):
                         uc = None
                         if cfg_scale != 1.0:
-                            uc = model.get_learned_conditioning(batch * [""])
+                            uc = model.get_learned_conditioning(batch_size * [""])
                         if isinstance(prompts, tuple):
                             prompts = list(prompts)
                         c = model.get_learned_conditioning(prompts)
                         shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
                         samples_ddim, _ = sampler.sample(S=steps,
                                                          conditioning=c,
-                                                         batch_size=batch,
+                                                         batch_size=batch_size,
                                                          shape=shape,
                                                          verbose=False,
                                                          unconditional_guidance_scale=cfg_scale,
@@ -218,24 +222,146 @@ class T2I:
                         seed = self._new_seed()
 
         if grid:
-            n_rows = batch if batch>1 else int(math.sqrt(batch * iterations))
-            # save as grid
-            grid = torch.stack(all_samples, 0)
-            grid = rearrange(grid, 'n b c h w -> (n b) c h w')
-            grid = make_grid(grid, nrow=n_rows)
-
-            # to image
-            grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
-            filename = os.path.join(outdir, f"{base_count:05}.png")
-            Image.fromarray(grid.astype(np.uint8)).save(filename)
-            for s in seeds:
-                images.append([filename,s])
+            images = self._make_grid(samples=all_samples,
+                                     seeds=seeds,
+                                     batch_size=batch_size,
+                                     iterations=iterations,
+                                     outdir=outdir)
 
         toc = time.time()
-        print(f'{batch * iterations} images generated in',"%4.2fs"% (toc-tic))
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
 
         return images
 
+    # There is a lot of code shared between this and txt2img; it should be refactored.
+    def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=None,
+                steps=None,seed=None,grid=None,individual=None,width=None,height=None,
+                cfg_scale=None,ddim_eta=None,strength=None):
+        """
+        Generate an image from the prompt and the initial image, writing iteration images into the outdir
+        The output is a list of lists in the format: [[filename1,seed1], [filename2,seed2],...]
+        """
+        outdir     = outdir     or self.outdir
+        steps      = steps      or self.steps
+        seed       = seed       or self.seed
+        cfg_scale  = cfg_scale  or self.cfg_scale
+        ddim_eta   = ddim_eta   or self.ddim_eta
+        batch_size = batch_size or self.batch_size
+        iterations = iterations or self.iterations
+        strength   = strength   or self.strength
+
+        if init_img is None:
+            print("no init_img provided!")
+            return []
+
+        model = self.load_model()  # will instantiate the model or return it from cache
+
+        # grid and individual are mutually exclusive, with individual taking priority.
+        # not necessary, but needed for compatibility with the dream bot
+        if (grid is None):
+            grid = self.grid
+        if individual:
+            grid = False
+
+        data = [batch_size * [prompt]]
+
+        # PLMS sampler not supported yet, so ignore previous sampler
+        if self.sampler_name!='ddim':
+            print(f"sampler '{self.sampler_name}' is not yet supported. Using DDIM sampler")
+            sampler = DDIMSampler(model)
+        else:
+            sampler = self.sampler
+
+        # make directories and establish names for the output files
+        os.makedirs(outdir, exist_ok=True)
+        base_count = len(os.listdir(outdir))-1
+
+        assert os.path.isfile(init_img)
+        init_image = self._load_img(init_img).to(self.device)
+        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
+        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space
+
+        sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
+
+        try:
+            assert 0. <= strength <= 1., 'can only work with strength in [0.0, 1.0]'
+        except AssertionError:
+            print(f"strength must be between 0.0 and 1.0, but received value {strength}")
+            return []
+
+        t_enc = int(strength * steps)
+        print(f"target t_enc is {t_enc} steps")
+
+        precision_scope = autocast if self.precision=="autocast" else nullcontext
+        images = list()
+        seeds  = list()
+
+        tic = time.time()
+
+        with torch.no_grad():
+            with precision_scope("cuda"):
+                with model.ema_scope():
+                    all_samples = list()
+                    for n in trange(iterations, desc="Sampling"):
+                        seed_everything(seed)
+                        for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                            uc = None
+                            if cfg_scale != 1.0:
+                                uc = model.get_learned_conditioning(batch_size * [""])
+                            if isinstance(prompts, tuple):
+                                prompts = list(prompts)
+                            c = model.get_learned_conditioning(prompts)
+
+                            # encode (scaled latent)
+                            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+                            # decode it
+                            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                                     unconditional_conditioning=uc,)
+
+                            x_samples = model.decode_first_stage(samples)
+                            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+
+                            if not grid:
+                                for x_sample in x_samples:
+                                    x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                                    filename = os.path.join(outdir, f"{base_count:05}.png")
+                                    Image.fromarray(x_sample.astype(np.uint8)).save(filename)
+                                    images.append([filename,seed])
+                                    base_count += 1
+                            else:
+                                all_samples.append(x_samples)
+                                seeds.append(seed)
+
+                        seed = self._new_seed()
+
+        if grid:
+            images = self._make_grid(samples=all_samples,
+                                     seeds=seeds,
+                                     batch_size=batch_size,
+                                     iterations=iterations,
+                                     outdir=outdir)
+
+        toc = time.time()
+        print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
+
+        return images
+
+    def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
+        images = list()
+        base_count = len(os.listdir(outdir))-1
+        n_rows = batch_size if batch_size>1 else int(math.sqrt(batch_size * iterations))
+        # save as grid
+        grid = torch.stack(samples, 0)
+        grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+        grid = make_grid(grid, nrow=n_rows)
+
+        # to image
+        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+        filename = os.path.join(outdir, f"{base_count:05}.png")
+        Image.fromarray(grid.astype(np.uint8)).save(filename)
+        for s in seeds:
+            images.append([filename,s])
+        return images
 
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
@@ -277,3 +403,13 @@ class T2I:
         model.eval()
         return model
 
+    def _load_img(self,path):
+        image = Image.open(path).convert("RGB")
+        w, h = image.size
+        print(f"loaded input image of size ({w}, {h}) from {path}")
+        w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
+        image = np.array(image).astype(np.float32) / 255.0
+        image = image[None].transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        return 2.*image - 1.
diff --git a/scripts/dream.py b/scripts/dream.py
index 0e8c19a233..869e407117 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -6,6 +6,8 @@ import shlex
 import atexit
 import os
 
+debugging = False
+
 def main():
     ''' Initialize command-line parsers and the diffusion model '''
     arg_parser = create_argv_parser()
@@ -24,7 +26,7 @@ def main():
     weights = "models/ldm/stable-diffusion-v1/model.ckpt"
 
     # command line history will be stored in a file called "~/.dream_history"
-    load_history()
+    setup_readline()
 
     print("* Initializing, be patient...\n")
     from pytorch_lightning import logging
@@ -36,7 +38,7 @@ def main():
     # the user input loop
     t2i = T2I(width=width,
               height=height,
-              batch=opt.batch,
+              batch_size=opt.batch_size,
               outdir=opt.outdir,
               sampler=opt.sampler,
               weights=weights,
@@ -50,7 +52,8 @@ def main():
     logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
 
     # preload the model
-    t2i.load_model()
+    if not debugging:
+        t2i.load_model()
     print("\n* Initialization done! Awaiting your command (-h for help)...")
 
     log_path = os.path.join(opt.outdir,"dream_log.txt")
@@ -92,7 +95,10 @@ def main_loop(t2i,parser,log):
             print("Try again with a prompt!")
             continue
 
-        results = t2i.txt2img(**vars(opt))
+        if opt.init_img is None:
+            results = t2i.txt2img(**vars(opt))
+        else:
+            results = t2i.img2img(**vars(opt))
         print("Outputs:")
         write_log_message(opt,switches,results,log)
 
@@ -136,7 +142,7 @@ def create_argv_parser():
                         type=int,
                         default=1,
                         help="number of images to generate")
-    parser.add_argument('-b','--batch',
+    parser.add_argument('-b','--batch_size',
                         type=int,
                         default=1,
                         help="number of images to produce per iteration (currently not working properly - producing too many images)")
@@ -158,14 +164,24 @@ def create_cmd_parser():
     parser.add_argument('-s','--steps',type=int,help="number of steps")
    parser.add_argument('-S','--seed',type=int,help="image seed")
     parser.add_argument('-n','--iterations',type=int,default=1,help="number of samplings to perform")
-    parser.add_argument('-b','--batch',type=int,default=1,help="number of images to produce per sampling (currently broken)")
+    parser.add_argument('-b','--batch_size',type=int,default=1,help="number of images to produce per sampling (currently broken)")
     parser.add_argument('-W','--width',type=int,help="image width, multiple of 64")
     parser.add_argument('-H','--height',type=int,help="image height, multiple of 64")
-    parser.add_argument('-C','--cfg_scale',type=float,help="prompt configuration scale (7.5)")
+    parser.add_argument('-C','--cfg_scale',default=7.5,type=float,help="prompt configuration scale")
     parser.add_argument('-g','--grid',action='store_true',help="generate a grid")
     parser.add_argument('-i','--individual',action='store_true',help="generate individual files (default)")
+    parser.add_argument('-I','--init_img',type=str,help="path to input image (supersedes width and height)")
+    parser.add_argument('-f','--strength',default=0.75,type=float,help="strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely")
     return parser
 
+def setup_readline():
+    readline.set_completer(Completer(['--steps','-s','--seed','-S','--iterations','-n','--batch_size','-b',
+                                      '--width','-W','--height','-H','--cfg_scale','-C','--grid','-g',
+                                      '--individual','-i','--init_img','-I','--strength','-f']).complete)
+    readline.set_completer_delims(" ")
+    readline.parse_and_bind('tab: complete')
+    load_history()
+
 def load_history():
     histfile = os.path.join(os.path.expanduser('~'),".dream_history")
     try:
@@ -175,5 +191,64 @@ def load_history():
         pass
     atexit.register(readline.write_history_file,histfile)
 
+class Completer():
+    def __init__(self,options):
+        self.options = sorted(options)
+        return
+
+    def complete(self,text,state):
+        if text.startswith('-I') or text.startswith('--init_img'):
+            return self._image_completions(text,state)
+
+        response = None
+        if state == 0:
+            # This is the first time for this text, so build a match list.
+            if text:
+                self.matches = [s
+                                for s in self.options
+                                if s and s.startswith(text)]
+            else:
+                self.matches = self.options[:]
+
+        # Return the state'th item from the match list,
+        # if we have that many.
+        try:
+            response = self.matches[state]
+        except IndexError:
+            response = None
+        return response
+
+    def _image_completions(self,text,state):
+        # get the path so far
+        if text.startswith('-I'):
+            path = text.replace('-I','',1).lstrip()
+        else:  # covers both '--init_img=<path>' and a bare '--init_img'
+            path = text.replace('--init_img','',1).lstrip(' =')
+
+        matches = list()
+
+        path = os.path.expanduser(path)
+        if len(path)==0:
+            matches.append(text+'./')
+        else:
+            dir = os.path.dirname(path)
+            dir_list = os.listdir(dir)
+            for n in dir_list:
+                if n.startswith('.') and len(n)>1:
+                    continue
+                full_path = os.path.join(dir,n)
+                if full_path.startswith(path):
+                    if os.path.isdir(full_path):
+                        matches.append(os.path.join(os.path.dirname(text),n)+'/')
+                    elif n.endswith('.png'):
+                        matches.append(os.path.join(os.path.dirname(text),n))
+
+        try:
+            response = matches[state]
+        except IndexError:
+            response = None
+        return response
+
+
 if __name__ == "__main__":
     main()
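
Reviewer note: below is a minimal sketch of the API surface this patch introduces (the batch -> batch_size rename, plus the new img2img entry point with init_img and strength). It is not part of the diff. The prompt and file paths are placeholders; the sketch assumes the v1 checkpoint is at its default location and that an input PNG exists on disk.

    from ldm.simplet2i import T2I

    t2i = T2I(batch_size=1, iterations=2, sampler='ddim')

    # txt2img is unchanged apart from the batch -> batch_size rename;
    # each result row is [filename, seed].
    for filename, seed in t2i.txt2img(prompt='an astronaut riding a horse',
                                      outdir='outputs/txt2img-samples'):
        print(f'{filename} (seed {seed})')

    # img2img is new: strength in [0.0, 1.0] sets how far the init image is
    # re-noised before decoding; internally t_enc = int(strength * steps).
    for filename, seed in t2i.img2img(prompt='an astronaut riding a horse',
                                      init_img='sketches/horse+rider.png',  # placeholder path
                                      strength=0.75,
                                      outdir='outputs/img2img-samples'):
        print(f'{filename} (seed {seed})')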