Keyboard interrupt retains seed and log information in files produced prior to interrupt. Closes #21

This commit is contained in:
Lincoln Stein 2022-08-23 00:51:38 -04:00
parent bc7b1fdd37
commit 710b908290
2 changed files with 109 additions and 101 deletions

View File

@ -186,8 +186,10 @@ The vast majority of these arguments default to reasonable values.
images = list()
seeds = list()
filename = None
image_count = 0
tic = time.time()
try:
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
@ -226,7 +228,7 @@ The vast majority of these arguments default to reasonable values.
else:
all_samples.append(x_samples_ddim)
seeds.append(seed)
image_count += 1
seed = self._new_seed()
if grid:
@ -236,8 +238,14 @@ The vast majority of these arguments default to reasonable values.
iterations=iterations,
outdir=outdir)
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
toc = time.time()
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images
@ -305,9 +313,11 @@ The vast majority of these arguments default to reasonable values.
images = list()
seeds = list()
filename = None
image_count = 0 # actual number of iterations performed
tic = time.time()
try:
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
@ -341,9 +351,8 @@ The vast majority of these arguments default to reasonable values.
else:
all_samples.append(x_samples)
seeds.append(seed)
image_count += 1
seed = self._new_seed()
if grid:
images = self._make_grid(samples=all_samples,
seeds=seeds,
@ -351,8 +360,13 @@ The vast majority of these arguments default to reasonable values.
iterations=iterations,
outdir=outdir)
except KeyboardInterrupt:
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
toc = time.time()
print(f'{batch_size * iterations} images generated in',"%4.2fs"% (toc-tic))
print(f'{image_count} images generated in',"%4.2fs"% (toc-tic))
return images

View File

@ -85,6 +85,7 @@ def main():
cmd_parser = create_cmd_parser()
main_loop(t2i,cmd_parser,log,infile)
log.close()
if infile:
infile.close()
@ -157,19 +158,12 @@ def main_loop(t2i,parser,log,infile):
print("Try again with a prompt!")
continue
try:
if opt.init_img is None:
results = t2i.txt2img(**vars(opt))
else:
results = t2i.img2img(**vars(opt))
print("Outputs:")
write_log_message(t2i,opt,results,log)
except KeyboardInterrupt:
print('*interrupted*')
continue
except RuntimeError as e:
print(str(e))
continue
print("goodbye!")