From 31b22e057d12daddd9a9b79c8d156288dfad3b95 Mon Sep 17 00:00:00 2001 From: Kevin Gibbons Date: Thu, 25 Aug 2022 17:01:17 -0700 Subject: [PATCH] switch to generators --- ldm/simplet2i.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 31721906d7..d63502831d 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -257,26 +257,25 @@ The vast majority of these arguments default to reasonable values. try: if init_img: assert os.path.exists(init_img),f'{init_img}: File not found' - get_images = self._img2img( + images_iterator = self._img2img(prompt, precision_scope=scope, batch_size=batch_size, steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, skip_normalize=skip_normalize, init_img=init_img,strength=strength) else: - get_images = self._txt2img( + images_iterator = self._txt2img(prompt, precision_scope=scope, batch_size=batch_size, steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta, skip_normalize=skip_normalize, width=width,height=height) - data = [batch_size * [prompt]] with scope(self.device.type), self.model.ema_scope(): for n in trange(iterations, desc="Sampling"): seed_everything(seed) - for prompts in tqdm(data, desc="data", dynamic_ncols=True): - iter_images = get_images(prompts) + for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True): + iter_images = next(images_iterator) for image in iter_images: results.append([image, seed]) if image_callback is not None: @@ -295,19 +294,20 @@ The vast majority of these arguments default to reasonable values. @torch.no_grad() def _txt2img(self, + prompt, precision_scope, batch_size, steps,cfg_scale,ddim_eta, skip_normalize, width,height): """ - Generate an image from the prompt + An infinite iterator of images from the prompt. """ sampler = self.sampler - def get_images(prompts): - uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize) + while True: + uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor] samples, _ = sampler.sample(S=steps, conditioning=c, @@ -317,18 +317,18 @@ The vast majority of these arguments default to reasonable values. unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta) - return self._samples_to_images(samples) - return get_images + yield self._samples_to_images(samples) @torch.no_grad() def _img2img(self, + prompt, precision_scope, batch_size, steps,cfg_scale,ddim_eta, skip_normalize, init_img,strength): """ - Generate an image from the prompt and the initial image + An infinite iterator of images from the prompt and the initial image """ # PLMS sampler not supported yet, so ignore previous sampler @@ -348,16 +348,15 @@ The vast majority of these arguments default to reasonable values. t_enc = int(strength * steps) # print(f"target t_enc is {t_enc} steps") - def get_images(prompts): - uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize) + while True: + uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize) # encode (scaled latent) z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device)) # decode it samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc,) - return self._samples_to_images(samples) - return get_images + yield self._samples_to_images(samples) # TODO: does this actually need to run every loop? does anything in it vary by random seed? def _get_uc_and_c(self, prompts, batch_size, skip_normalize): @@ -393,7 +392,6 @@ The vast majority of these arguments default to reasonable values. images.append(image) return images - def _new_seed(self): self.seed = random.randrange(0,np.iinfo(np.uint32).max) return self.seed