From 710b9082901e21fdf15682dde319e9d57ca98234 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 23 Aug 2022 00:51:38 -0400 Subject: [PATCH] Keyboard interrupt retains seed and log information in files produced prior to interrupt. Closes #21 --- ldm/simplet2i.py | 188 +++++++++++++++++++++++++---------------------- scripts/dream.py | 22 ++---- 2 files changed, 109 insertions(+), 101 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index a4eac1af53..45f91d3dde 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -186,58 +186,66 @@ The vast majority of these arguments default to reasonable values. images = list() seeds = list() filename = None + image_count = 0 tic = time.time() + + try: + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(iterations, desc="Sampling"): + seed_everything(seed) + for prompts in tqdm(data, desc="data", dynamic_ncols=True): + uc = None + if cfg_scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] + samples_ddim, _ = sampler.sample(S=steps, + conditioning=c, + batch_size=batch_size, + shape=shape, + verbose=False, + unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + x_T=start_code) + + x_samples_ddim = model.decode_first_stage(samples_ddim) + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + + if not grid: + for x_sample in x_samples_ddim: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + filename = self._unique_filename(outdir,previousname=filename, + seed=seed,isbatch=(batch_size>1)) + assert not os.path.exists(filename) + Image.fromarray(x_sample.astype(np.uint8)).save(filename) + images.append([filename,seed]) + else: + all_samples.append(x_samples_ddim) + seeds.append(seed) + image_count += 1 + seed = self._new_seed() + + if grid: + images = self._make_grid(samples=all_samples, + seeds=seeds, + batch_size=batch_size, + iterations=iterations, + outdir=outdir) - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - all_samples = list() - for n in trange(iterations, desc="Sampling"): - seed_everything(seed) - for prompts in tqdm(data, desc="data", dynamic_ncols=True): - uc = None - if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] - samples_ddim, _ = sampler.sample(S=steps, - conditioning=c, - batch_size=batch_size, - shape=shape, - verbose=False, - unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - x_T=start_code) - - x_samples_ddim = model.decode_first_stage(samples_ddim) - x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - - if not grid: - for x_sample in x_samples_ddim: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,previousname=filename, - seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) - Image.fromarray(x_sample.astype(np.uint8)).save(filename) - images.append([filename,seed]) - else: - all_samples.append(x_samples_ddim) - seeds.append(seed) - - seed = self._new_seed() - - if grid: - images = self._make_grid(samples=all_samples, - seeds=seeds, - batch_size=batch_size, - iterations=iterations, - outdir=outdir) + except KeyboardInterrupt: + print('*interrupted*') + print('Partial results will be returned; if --grid was requested, nothing will be returned.') + except RuntimeError as e: + print(str(e)) toc = time.time() - print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic)) + print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) return images @@ -305,54 +313,60 @@ The vast majority of these arguments default to reasonable values. images = list() seeds = list() filename = None + image_count = 0 # actual number of iterations performed tic = time.time() - - with torch.no_grad(): - with precision_scope("cuda"): - with model.ema_scope(): - all_samples = list() - for n in trange(iterations, desc="Sampling"): - seed_everything(seed) - for prompts in tqdm(data, desc="data", dynamic_ncols=True): - uc = None - if cfg_scale != 1.0: - uc = model.get_learned_conditioning(batch_size * [""]) - if isinstance(prompts, tuple): - prompts = list(prompts) - c = model.get_learned_conditioning(prompts) - # encode (scaled latent) - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) - # decode it - samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, - unconditional_conditioning=uc,) + try: + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + all_samples = list() + for n in trange(iterations, desc="Sampling"): + seed_everything(seed) + for prompts in tqdm(data, desc="data", dynamic_ncols=True): + uc = None + if cfg_scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) - x_samples = model.decode_first_stage(samples) - x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, + unconditional_conditioning=uc,) - if not grid: - for x_sample in x_samples: - x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') - filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1)) - assert not os.path.exists(filename) - Image.fromarray(x_sample.astype(np.uint8)).save(filename) - images.append([filename,seed]) - else: - all_samples.append(x_samples) - seeds.append(seed) + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) - seed = self._new_seed() + if not grid: + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1)) + assert not os.path.exists(filename) + Image.fromarray(x_sample.astype(np.uint8)).save(filename) + images.append([filename,seed]) + else: + all_samples.append(x_samples) + seeds.append(seed) + image_count += 1 + seed = self._new_seed() + if grid: + images = self._make_grid(samples=all_samples, + seeds=seeds, + batch_size=batch_size, + iterations=iterations, + outdir=outdir) - if grid: - images = self._make_grid(samples=all_samples, - seeds=seeds, - batch_size=batch_size, - iterations=iterations, - outdir=outdir) + except KeyboardInterrupt: + print('Partial results will be returned; if --grid was requested, nothing will be returned.') + except RuntimeError as e: + print(str(e)) toc = time.time() - print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic)) + print(f'{image_count} images generated in',"%4.2fs"% (toc-tic)) return images diff --git a/scripts/dream.py b/scripts/dream.py index 74cd4efcdd..0e511f7789 100755 --- a/scripts/dream.py +++ b/scripts/dream.py @@ -85,7 +85,8 @@ def main(): cmd_parser = create_cmd_parser() main_loop(t2i,cmd_parser,log,infile) log.close() - infile.close() + if infile: + infile.close() def main_loop(t2i,parser,log,infile): @@ -157,19 +158,12 @@ def main_loop(t2i,parser,log,infile): print("Try again with a prompt!") continue - try: - if opt.init_img is None: - results = t2i.txt2img(**vars(opt)) - else: - results = t2i.img2img(**vars(opt)) - print("Outputs:") - write_log_message(t2i,opt,results,log) - except KeyboardInterrupt: - print('*interrupted*') - continue - except RuntimeError as e: - print(str(e)) - continue + if opt.init_img is None: + results = t2i.txt2img(**vars(opt)) + else: + results = t2i.img2img(**vars(opt)) + print("Outputs:") + write_log_message(t2i,opt,results,log) print("goodbye!")