From a10baf58082defe05311c2d90614c387651301bd Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 15:13:07 -0700
Subject: [PATCH] factor out exception handler

---
 ldm/simplet2i.py | 248 +++++++++++++++++++++++------------------------
 1 file changed, 124 insertions(+), 124 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 0ec3d60d98..157f55fbcb 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -52,7 +52,7 @@ t2i = T2I(model = <path> // models/ldm/stable-diffusion-v1/model.ck
 # do the slow model initialization
 t2i.load_model()
 
-# Do the fast inference & image generation. Any options passed here 
+# Do the fast inference & image generation. Any options passed here
 # override the default values assigned during class initialization
 # Will call load_model() if the model was not previously loaded and so
 # may be slow at first.
@@ -70,7 +70,7 @@ results = t2i.prompt2png(prompt = "an astronaut riding a horse",
                          outdir = "./outputs/",
                          iterations = 3,
                          init_img = "./sketches/horse+rider.png")
-
+
 for row in results:
     print(f'filename={row[0]}')
     print(f'seed  ={row[1]}')
@@ -181,7 +181,7 @@ The vast majority of these arguments default to reasonable values.
         outdir = kwargs.get('outdir','outputs/img-samples')
         assert 'init_img' in kwargs,'call to img2img() must include the init_img argument'
         return self.prompt2png(prompt,outdir,**kwargs)
-
+
     def prompt2image(self,
                      # these are common
                      prompt,
@@ -219,7 +219,7 @@ The vast majority of these arguments default to reasonable values.
           callback // a function or method that will be called each time an image is generated
 
           To use the callback, define a function or method that receives two arguments, an Image object
-          and the seed. You can then do whatever you like with the image, including converting it to 
+          and the seed. You can then do whatever you like with the image, including converting it to
           different formats and manipulating it. For example:
 
              def process_image(image,seed):
@@ -252,28 +252,41 @@ The vast majority of these arguments default to reasonable values.
         data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext
 
-        tic = time.time()
-        if init_img:
-            assert os.path.exists(init_img),f'{init_img}: File not found'
-            results = self._img2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    init_img=init_img,strength=strength,variants=variants,
-                                    callback=image_callback)
-        else:
-            results = self._txt2img(prompt,
-                                    data=data,precision_scope=scope,
-                                    batch_size=batch_size,iterations=iterations,
-                                    steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
-                                    skip_normalize=skip_normalize,
-                                    width=width,height=height,
-                                    callback=image_callback)
+        tic = time.time()
+        results = list()
+        def prompt_callback(image, seed):
+            results.append([image, seed])
+            if image_callback is not None:
+                image_callback(image, seed)
+
+        try:
+            if init_img:
+                assert os.path.exists(init_img),f'{init_img}: File not found'
+                self._img2img(prompt,
+                              data=data,precision_scope=scope,
+                              batch_size=batch_size,iterations=iterations,
+                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                              skip_normalize=skip_normalize,
+                              init_img=init_img,strength=strength,variants=variants,
+                              callback=prompt_callback)
+            else:
+                self._txt2img(prompt,
+                              data=data,precision_scope=scope,
+                              batch_size=batch_size,iterations=iterations,
+                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                              skip_normalize=skip_normalize,
+                              width=width,height=height,
+                              callback=prompt_callback)
+        except KeyboardInterrupt:
+            print('*interrupted*')
+            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
+        except RuntimeError as e:
+            print(str(e))
+
         toc = time.time()
         print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
         return results
-
+
     @torch.no_grad()
     def _txt2img(self,prompt,
                  data,precision_scope,
@@ -292,62 +305,56 @@ The vast majority of these arguments default to reasonable values.
         image_count = 0
 
         # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
-        try:
-            with precision_scope(self.device.type), self.model.ema_scope():
-                all_samples = list()
-                for n in trange(iterations, desc="Sampling"):
-                    seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        uc = None
-                        if cfg_scale != 1.0:
-                            uc = self.model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
+        with precision_scope(self.device.type), self.model.ema_scope():
+            all_samples = list()
+            for n in trange(iterations, desc="Sampling"):
+                seed_everything(seed)
+                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                    uc = None
+                    if cfg_scale != 1.0:
+                        uc = self.model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
 
-                        # weighted sub-prompts
-                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                        if len(subprompts) > 1:
-                            # i dont know if this is correct.. but it works
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
+                    # weighted sub-prompts
+                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                    if len(subprompts) > 1:
+                        # i dont know if this is correct.. but it works
+                        c = torch.zeros_like(uc)
+                        # get total weight for normalizing
+                        totalWeight = sum(weights)
+                        # normalize each "sub prompt" and add it
+                        for i in range(0,len(subprompts)):
+                            weight = weights[i]
+                            if not skip_normalize:
+                                weight = weight / totalWeight
+                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else: # just standard 1 prompt
+                        c = self.model.get_learned_conditioning(prompts)
 
-                        shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-                        samples_ddim, _ = sampler.sample(S=steps,
-                                                         conditioning=c,
-                                                         batch_size=batch_size,
-                                                         shape=shape,
-                                                         verbose=False,
-                                                         unconditional_guidance_scale=cfg_scale,
-                                                         unconditional_conditioning=uc,
-                                                         eta=ddim_eta)
+                    shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+                    samples_ddim, _ = sampler.sample(S=steps,
+                                                     conditioning=c,
+                                                     batch_size=batch_size,
+                                                     shape=shape,
+                                                     verbose=False,
+                                                     unconditional_guidance_scale=cfg_scale,
+                                                     unconditional_conditioning=uc,
+                                                     eta=ddim_eta)
 
-                        x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        for x_sample in x_samples_ddim:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-
-                        seed = self._new_seed()
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print(str(e))
+                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
+                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+                    for x_sample in x_samples_ddim:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image = Image.fromarray(x_sample.astype(np.uint8))
+                        images.append([image,seed])
+                        if callback is not None:
+                            callback(image,seed)
+
+                    seed = self._new_seed()
 
         return images
-
+
     @torch.no_grad()
     def _img2img(self,prompt,
                  data,precision_scope,
@@ -374,62 +381,55 @@ The vast majority of these arguments default to reasonable values.
         init_latent = self.model.get_first_stage_encoding(self.model.encode_first_stage(init_image))  # move to latent space
 
         sampler.make_schedule(ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False)
-
+
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")
         images = list()
 
-        try:
-            with precision_scope(self.device.type), self.model.ema_scope():
-                all_samples = list()
-                for n in trange(iterations, desc="Sampling"):
-                    seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        uc = None
-                        if cfg_scale != 1.0:
-                            uc = self.model.get_learned_conditioning(batch_size * [""])
-                        if isinstance(prompts, tuple):
-                            prompts = list(prompts)
+        with precision_scope(self.device.type), self.model.ema_scope():
+            all_samples = list()
+            for n in trange(iterations, desc="Sampling"):
+                seed_everything(seed)
+                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                    uc = None
+                    if cfg_scale != 1.0:
+                        uc = self.model.get_learned_conditioning(batch_size * [""])
+                    if isinstance(prompts, tuple):
+                        prompts = list(prompts)
 
-                        # weighted sub-prompts
-                        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                        if len(subprompts) > 1:
-                            # i dont know if this is correct.. but it works
-                            c = torch.zeros_like(uc)
-                            # get total weight for normalizing
-                            totalWeight = sum(weights)
-                            # normalize each "sub prompt" and add it
-                            for i in range(0,len(subprompts)):
-                                weight = weights[i]
-                                if not skip_normalize:
-                                    weight = weight / totalWeight
-                                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                        else: # just standard 1 prompt
-                            c = self.model.get_learned_conditioning(prompts)
+                    # weighted sub-prompts
+                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+                    if len(subprompts) > 1:
+                        # i dont know if this is correct.. but it works
+                        c = torch.zeros_like(uc)
+                        # get total weight for normalizing
+                        totalWeight = sum(weights)
+                        # normalize each "sub prompt" and add it
+                        for i in range(0,len(subprompts)):
+                            weight = weights[i]
+                            if not skip_normalize:
+                                weight = weight / totalWeight
+                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else: # just standard 1 prompt
+                        c = self.model.get_learned_conditioning(prompts)
 
-                        # encode (scaled latent)
-                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-                        # decode it
-                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-                                                 unconditional_conditioning=uc,)
+                    # encode (scaled latent)
+                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+                    # decode it
+                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                             unconditional_conditioning=uc,)
 
-                        x_samples = self.model.decode_first_stage(samples)
-                        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
+                    x_samples = self.model.decode_first_stage(samples)
+                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
-                        for x_sample in x_samples:
-                            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                            image = Image.fromarray(x_sample.astype(np.uint8))
-                            images.append([image,seed])
-                            if callback is not None:
-                                callback(image,seed)
-                        seed = self._new_seed()
+                    for x_sample in x_samples:
+                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+                        image = Image.fromarray(x_sample.astype(np.uint8))
+                        images.append([image,seed])
+                        if callback is not None:
+                            callback(image,seed)
+                    seed = self._new_seed()
 
-        except KeyboardInterrupt:
-            print('*interrupted*')
-            print('Partial results will be returned; if --grid was requested, nothing will be returned.')
-        except RuntimeError as e:
-            print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
-            traceback.print_exc()
         return images
 
     def _new_seed(self):
@@ -476,7 +476,7 @@ The vast majority of these arguments default to reasonable values.
             print(msg)
 
         return self.model
-
+
     def _load_model_from_config(self, config, ckpt):
         print(f"Loading model from {ckpt}")
         pl_sd = torch.load(ckpt, map_location="cpu")
@@ -507,7 +507,7 @@ The vast majority of these arguments default to reasonable values.
 
     def _split_weighted_subprompts(text):
         """
-        grabs all text up to the first occurrence of ':' 
+        grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
@@ -523,7 +523,7 @@ The vast majority of these arguments default to reasonable values.
             remaining -= idx
             # remove from main text
             text = text[idx+1:]
-            # find value for weight 
+            # find value for weight
             if " " in text:
                 idx = text.index(" ") # first occurrence
             else: # no space, read to end
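
The shape of the refactor is easier to see outside the class. In the patch, _txt2img and
_img2img stop catching exceptions and stop returning result lists; prompt2image installs a
prompt_callback closure that accumulates each (image, seed) pair and forwards it to the
caller's image_callback, with a single try/except around the dispatch. Below is a minimal
standalone sketch of that pattern; the names generate_images and fake_worker are hypothetical
stand-ins for prompt2image and _txt2img/_img2img, and plain strings stand in for PIL images.

    # Sketch only: mirrors the structure T2I.prompt2image has after this patch.
    def generate_images(worker, image_callback=None):
        results = list()

        # Same shape as prompt_callback in the patch: accumulate, then forward.
        def prompt_callback(image, seed):
            results.append([image, seed])
            if image_callback is not None:
                image_callback(image, seed)

        try:
            # `worker` stands in for self._txt2img / self._img2img, which now
            # just invoke `callback` once per generated image.
            worker(callback=prompt_callback)
        except KeyboardInterrupt:
            # Ctrl-C mid-run: fall through and return whatever finished.
            print('*interrupted*')
        except RuntimeError as e:
            print(str(e))
        return results

    def fake_worker(callback):
        # Pretend to generate two images; strings stand in for PIL images.
        for seed in (42, 43):
            callback(f'image-{seed}', seed)

    print(generate_images(fake_worker))
    # -> [['image-42', 42], ['image-43', 43]]

Because accumulation happens in the caller's closure rather than inside the workers, an
interrupt partway through generation still leaves every already-finished image in results,
which is what lets a single handler keep the "partial results will be returned" promise.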