factor out exception handler

This commit is contained in:
Kevin Gibbons 2022-08-25 15:13:07 -07:00
parent 0eba55ddbc
commit a10baf5808

View File

@ -252,24 +252,37 @@ The vast majority of these arguments default to reasonable values.
data = [batch_size * [prompt]] data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext scope = autocast if self.precision=="autocast" else nullcontext
tic = time.time() tic = time.time()
if init_img: results = list()
assert os.path.exists(init_img),f'{init_img}: File not found' def prompt_callback(image, seed):
results = self._img2img(prompt, results.append([image, seed])
data=data,precision_scope=scope, if image_callback is not None:
batch_size=batch_size,iterations=iterations, image_callback(image, seed)
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize, try:
init_img=init_img,strength=strength,variants=variants, if init_img:
callback=image_callback) assert os.path.exists(init_img),f'{init_img}: File not found'
else: self._img2img(prompt,
results = self._txt2img(prompt, data=data,precision_scope=scope,
data=data,precision_scope=scope, batch_size=batch_size,iterations=iterations,
batch_size=batch_size,iterations=iterations, steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta, skip_normalize=skip_normalize,
skip_normalize=skip_normalize, init_img=init_img,strength=strength,variants=variants,
width=width,height=height, callback=prompt_callback)
callback=image_callback) else:
self._txt2img(prompt,
data=data,precision_scope=scope,
batch_size=batch_size,iterations=iterations,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,height=height,
callback=prompt_callback)
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
toc = time.time() toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic)) print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
return results return results
@ -292,59 +305,53 @@ The vast majority of these arguments default to reasonable values.
image_count = 0 image_count = 0
# Gawd. Too many levels of indent here. Need to refactor into smaller routines! # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
try: with precision_scope(self.device.type), self.model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list()
all_samples = list() for n in trange(iterations, desc="Sampling"):
for n in trange(iterations, desc="Sampling"): seed_everything(seed)
seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True):
for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None
uc = None if cfg_scale != 1.0:
if cfg_scale != 1.0: uc = self.model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple):
if isinstance(prompts, tuple): prompts = list(prompts)
prompts = list(prompts)
# weighted sub-prompts # weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1: if len(subprompts) > 1:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
# get total weight for normalizing # get total weight for normalizing
totalWeight = sum(weights) totalWeight = sum(weights)
# normalize each "sub prompt" and add it # normalize each "sub prompt" and add it
for i in range(0,len(subprompts)): for i in range(0,len(subprompts)):
weight = weights[i] weight = weights[i]
if not skip_normalize: if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts) c = self.model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps, samples_ddim, _ = sampler.sample(S=steps,
conditioning=c, conditioning=c,
batch_size=batch_size, batch_size=batch_size,
shape=shape, shape=shape,
verbose=False, verbose=False,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
eta=ddim_eta) eta=ddim_eta)
x_samples_ddim = self.model.decode_first_stage(samples_ddim) x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim: for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8)) image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed]) images.append([image,seed])
if callback is not None: if callback is not None:
callback(image,seed) callback(image,seed)
seed = self._new_seed() seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
return images return images
@ -379,57 +386,50 @@ The vast majority of these arguments default to reasonable values.
# print(f"target t_enc is {t_enc} steps") # print(f"target t_enc is {t_enc} steps")
images = list() images = list()
try: with precision_scope(self.device.type), self.model.ema_scope():
with precision_scope(self.device.type), self.model.ema_scope(): all_samples = list()
all_samples = list() for n in trange(iterations, desc="Sampling"):
for n in trange(iterations, desc="Sampling"): seed_everything(seed)
seed_everything(seed) for prompts in tqdm(data, desc="data", dynamic_ncols=True):
for prompts in tqdm(data, desc="data", dynamic_ncols=True): uc = None
uc = None if cfg_scale != 1.0:
if cfg_scale != 1.0: uc = self.model.get_learned_conditioning(batch_size * [""])
uc = self.model.get_learned_conditioning(batch_size * [""]) if isinstance(prompts, tuple):
if isinstance(prompts, tuple): prompts = list(prompts)
prompts = list(prompts)
# weighted sub-prompts # weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0]) subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1: if len(subprompts) > 1:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
# get total weight for normalizing # get total weight for normalizing
totalWeight = sum(weights) totalWeight = sum(weights)
# normalize each "sub prompt" and add it # normalize each "sub prompt" and add it
for i in range(0,len(subprompts)): for i in range(0,len(subprompts)):
weight = weights[i] weight = weights[i]
if not skip_normalize: if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight) c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts) c = self.model.get_learned_conditioning(prompts)
# encode (scaled latent) # encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it # decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,) unconditional_conditioning=uc,)
x_samples = self.model.decode_first_stage(samples) x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples: for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8)) image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed]) images.append([image,seed])
if callback is not None: if callback is not None:
callback(image,seed) callback(image,seed)
seed = self._new_seed() seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
traceback.print_exc()
return images return images
def _new_seed(self): def _new_seed(self):