From 078859207df4a6149cb8cdf7d4d9b4bb1fef1ae6 Mon Sep 17 00:00:00 2001
From: Kevin Gibbons
Date: Thu, 25 Aug 2022 15:19:44 -0700
Subject: [PATCH] factor out loop

---
 ldm/simplet2i.py | 226 ++++++++++++++++++++---------------------
 1 file changed, 94 insertions(+), 132 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 157f55fbcb..31721906d7 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -216,7 +216,7 @@ The vast majority of these arguments default to reasonable values.
            strength                        // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
            ddim_eta                        // image randomness (eta=0.0 means the same seed always produces the same image)
            variants                        // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
-           callback                        // a function or method that will be called each time an image is generated
+           image_callback                  // a function or method that will be called each time an image is generated
 
        To use the callback, define a function of method that receives two arguments, an Image object and the seed.
        You can then do whatever you like with the image, including converting it to
@@ -249,34 +249,40 @@ The vast majority of these arguments default to reasonable values.
         height = h
         width  = w
-        data = [batch_size * [prompt]]
         scope = autocast if self.precision=="autocast" else nullcontext
 
         tic     = time.time()
         results = list()
-        def prompt_callback(image, seed):
-            results.append([image, seed])
-            if image_callback is not None:
-                image_callback(image, seed)
 
         try:
             if init_img:
                 assert os.path.exists(init_img),f'{init_img}: File not found'
-                self._img2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                get_images = self._img2img(
+                              precision_scope=scope,
+                              batch_size=batch_size,
+                              steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
-                              init_img=init_img,strength=strength,variants=variants,
-                              callback=prompt_callback)
+                              init_img=init_img,strength=strength)
             else:
-                self._txt2img(prompt,
-                              data=data,precision_scope=scope,
-                              batch_size=batch_size,iterations=iterations,
-                              steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
+                get_images = self._txt2img(
+                              precision_scope=scope,
+                              batch_size=batch_size,
+                              steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                               skip_normalize=skip_normalize,
-                              width=width,height=height,
-                              callback=prompt_callback)
+                              width=width,height=height)
+
+            data = [batch_size * [prompt]]
+            with scope(self.device.type), self.model.ema_scope():
+                for n in trange(iterations, desc="Sampling"):
+                    seed_everything(seed)
+                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
+                        iter_images = get_images(prompts)
+                        for image in iter_images:
+                            results.append([image, seed])
+                            if image_callback is not None:
+                                image_callback(image,seed)
+                    seed = self._new_seed()
+
         except KeyboardInterrupt:
             print('*interrupted*')
             print('Partial results will be returned; if --grid was requested, nothing will be returned.')
@@ -288,84 +294,41 @@ The vast majority of these arguments default to reasonable values.
         return results
 
     @torch.no_grad()
-    def _txt2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _txt2img(self,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
-                 width,height,
-                 callback): # the callback is called each time a new Image is generated
+                 width,height):
         """
-        Generate an image from the prompt, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
+        Generate an image from the prompt
         """
-        sampler = self.sampler
-        images = list()
-        image_count = 0
+        sampler = self.sampler
 
-        # Gawd. Too many levels of indent here. Need to refactor into smaller routines!
-        with precision_scope(self.device.type), self.model.ema_scope():
-            all_samples = list()
-            for n in trange(iterations, desc="Sampling"):
-                seed_everything(seed)
-                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                    uc = None
-                    if cfg_scale != 1.0:
-                        uc = self.model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-
-                    # weighted sub-prompts
-                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        # i dont know if this is correct.. but it works
-                        c = torch.zeros_like(uc)
-                        # get total weight for normalizing
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(0,len(subprompts)):
-                            weight = weights[i]
-                            if not skip_normalize:
-                                weight = weight / totalWeight
-                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else: # just standard 1 prompt
-                        c = self.model.get_learned_conditioning(prompts)
-
-                    shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
-                    samples_ddim, _ = sampler.sample(S=steps,
-                                                     conditioning=c,
-                                                     batch_size=batch_size,
-                                                     shape=shape,
-                                                     verbose=False,
-                                                     unconditional_guidance_scale=cfg_scale,
-                                                     unconditional_conditioning=uc,
-                                                     eta=ddim_eta)
-
-                    x_samples_ddim = self.model.decode_first_stage(samples_ddim)
-                    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                    for x_sample in x_samples_ddim:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        image = Image.fromarray(x_sample.astype(np.uint8))
-                        images.append([image,seed])
-                        if callback is not None:
-                            callback(image,seed)
-
-                seed = self._new_seed()
-
-        return images
+        def get_images(prompts):
+            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+            shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
+            samples, _ = sampler.sample(S=steps,
+                                        conditioning=c,
+                                        batch_size=batch_size,
+                                        shape=shape,
+                                        verbose=False,
+                                        unconditional_guidance_scale=cfg_scale,
+                                        unconditional_conditioning=uc,
+                                        eta=ddim_eta)
+            return self._samples_to_images(samples)
+        return get_images
 
     @torch.no_grad()
-    def _img2img(self,prompt,
-                 data,precision_scope,
-                 batch_size,iterations,
-                 steps,seed,cfg_scale,ddim_eta,
+    def _img2img(self,
+                 precision_scope,
+                 batch_size,
+                 steps,cfg_scale,ddim_eta,
                  skip_normalize,
-                 init_img,strength,variants,
-                 callback):
+                 init_img,strength):
         """
-        Generate an image from the prompt and the initial image, writing iteration images into the outdir
-        The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
+        Generate an image from the prompt and the initial image
         """
 
         # PLMS sampler not supported yet, so ignore previous sampler
@@ -384,54 +347,53 @@ The vast majority of these arguments default to reasonable values.
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")
+
+        def get_images(prompts):
+            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+
+            # encode (scaled latent)
+            z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
+            # decode it
+            samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
+                                     unconditional_conditioning=uc,)
+            return self._samples_to_images(samples)
+        return get_images
+
+    # TODO: does this actually need to run every loop? does anything in it vary by random seed?
+    def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
+        if isinstance(prompts, tuple):
+            prompts = list(prompts)
+
+        uc = self.model.get_learned_conditioning(batch_size * [""])
+
+        # weighted sub-prompts
+        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+        if len(subprompts) > 1:
+            # i dont know if this is correct.. but it works
+            c = torch.zeros_like(uc)
+            # get total weight for normalizing
+            totalWeight = sum(weights)
+            # normalize each "sub prompt" and add it
+            for i in range(0,len(subprompts)):
+                weight = weights[i]
+                if not skip_normalize:
+                    weight = weight / totalWeight
+                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+        else: # just standard 1 prompt
+            c = self.model.get_learned_conditioning(prompts)
+        return (uc, c)
+
+    def _samples_to_images(self, samples):
+        x_samples = self.model.decode_first_stage(samples)
+        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         images = list()
-
-        with precision_scope(self.device.type), self.model.ema_scope():
-            all_samples = list()
-            for n in trange(iterations, desc="Sampling"):
-                seed_everything(seed)
-                for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                    uc = None
-                    if cfg_scale != 1.0:
-                        uc = self.model.get_learned_conditioning(batch_size * [""])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-
-                    # weighted sub-prompts
-                    subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        # i dont know if this is correct.. but it works
-                        c = torch.zeros_like(uc)
-                        # get total weight for normalizing
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(0,len(subprompts)):
-                            weight = weights[i]
-                            if not skip_normalize:
-                                weight = weight / totalWeight
-                            c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else: # just standard 1 prompt
-                        c = self.model.get_learned_conditioning(prompts)
-
-                    # encode (scaled latent)
-                    z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
-                    # decode it
-                    samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
-                                             unconditional_conditioning=uc,)
-
-                    x_samples = self.model.decode_first_stage(samples)
-                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
-
-                    for x_sample in x_samples:
-                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
-                        image = Image.fromarray(x_sample.astype(np.uint8))
-                        images.append([image,seed])
-                        if callback is not None:
-                            callback(image,seed)
-                    seed = self._new_seed()
-
+        for x_sample in x_samples:
+            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+            image = Image.fromarray(x_sample.astype(np.uint8))
+            images.append(image)
         return images
+
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
         return self.seed
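
Reviewer note: a quick sketch of the control flow this patch leaves behind, for anyone
reading the diff in isolation. _txt2img/_img2img no longer loop, seed, or collect results
themselves; each now returns a get_images(prompts) closure that captures the fixed sampling
settings, while the caller owns the iteration/seed loop and the image_callback. The toy
Python below mirrors that shape only; make_get_images and run_loop are hypothetical names,
and plain strings stand in for real sampling and decoding.

    import random

    def make_get_images(steps):
        # Capture fixed settings once, as _txt2img/_img2img now do,
        # and hand back a function of the per-iteration data only.
        def get_images(prompts):
            # Stand-in for sampler.sample() + decode_first_stage().
            return [f'image<{p}, steps={steps}>' for p in prompts]
        return get_images

    def run_loop(data, iterations, seed, get_images, image_callback=None):
        # Mirror of the caller's new loop: reseed, generate, collect, notify.
        results = []
        for _ in range(iterations):
            random.seed(seed)
            for prompts in data:
                for image in get_images(prompts):
                    results.append([image, seed])
                    if image_callback is not None:
                        image_callback(image, seed)
            seed = random.randrange(0, 2**32)
        return results

    get_images = make_get_images(steps=50)
    results = run_loop([['a cat']], iterations=2, seed=42,
                       get_images=get_images,
                       image_callback=lambda img, s: print(s, img))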