Keyboard interrupt retains seed and log information in files produced prior to interrupt. Closes #21

This commit is contained in:
Lincoln Stein 2022-08-23 00:51:38 -04:00
parent bc7b1fdd37
commit 710b908290
2 changed files with 109 additions and 101 deletions

View File

@ -186,58 +186,66 @@ The vast majority of these arguments default to reasonable values.
images = list() images = list()
seeds = list() seeds = list()
filename = None filename = None
image_count = 0
tic = time.time() tic = time.time()
try:
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,previousname=filename,
seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples_ddim)
seeds.append(seed)
image_count += 1
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
with torch.no_grad(): except KeyboardInterrupt:
with precision_scope("cuda"): print('*interrupted*')
with model.ema_scope(): print('Partial results will be returned; if --grid was requested, nothing will be returned.')
all_samples = list() except RuntimeError as e:
for n in trange(iterations, desc="Sampling"): print(str(e))
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=start_code)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
if not grid:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,previousname=filename,
seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples_ddim)
seeds.append(seed)
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
toc = time.time() toc = time.time()
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic)) print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images return images
@ -305,54 +313,60 @@ The vast majority of these arguments default to reasonable values.
images = list() images = list()
seeds = list() seeds = list()
filename = None filename = None
image_count = 0 # actual number of iterations performed
tic = time.time() tic = time.time()
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
# encode (scaled latent) try:
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) with torch.no_grad():
# decode it with precision_scope("cuda"):
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, with model.ema_scope():
unconditional_conditioning=uc,) all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
x_samples = model.decode_first_stage(samples) # encode (scaled latent)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
if not grid: x_samples = model.decode_first_stage(samples)
for x_sample in x_samples: x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples)
seeds.append(seed)
seed = self._new_seed() if not grid:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
filename = self._unique_filename(outdir,filename,seed=seed,isbatch=(batch_size>1))
assert not os.path.exists(filename)
Image.fromarray(x_sample.astype(np.uint8)).save(filename)
images.append([filename,seed])
else:
all_samples.append(x_samples)
seeds.append(seed)
image_count += 1
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
batch_size=batch_size,
iterations=iterations,
outdir=outdir)
if grid: except KeyboardInterrupt:
images = self._make_grid(samples=all_samples, print('Partial results will be returned; if --grid was requested, nothing will be returned.')
seeds=seeds, except RuntimeError as e:
batch_size=batch_size, print(str(e))
iterations=iterations,
outdir=outdir)
toc = time.time() toc = time.time()
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic)) print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images return images

View File

@ -85,7 +85,8 @@ def main():
cmd_parser = create_cmd_parser() cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile) main_loop(t2i,cmd_parser,log,infile)
log.close() log.close()
infile.close() if infile:
infile.close()
def main_loop(t2i,parser,log,infile): def main_loop(t2i,parser,log,infile):
@ -157,19 +158,12 @@ def main_loop(t2i,parser,log,infile):
print("Try again with a prompt!") print("Try again with a prompt!")
continue continue
try: if opt.init_img is None:
if opt.init_img is None: results = t2i.txt2img(**vars(opt))
results = t2i.txt2img(**vars(opt)) else:
else: results = t2i.img2img(**vars(opt))
results = t2i.img2img(**vars(opt)) print("Outputs:")
print("Outputs:") write_log_message(t2i,opt,results,log)
write_log_message(t2i,opt,results,log)
except KeyboardInterrupt:
print('*interrupted*')
continue
except RuntimeError as e:
print(str(e))
continue
print("goodbye!") print("goodbye!")