From fcdd95b652b8f503528d3e3526590e3c7ab2f3dc Mon Sep 17 00:00:00 2001 From: Sean McLellan Date: Fri, 26 Aug 2022 00:39:57 -0400 Subject: [PATCH] Refactor so that behavior is consolidated at top level --- ldm/simplet2i.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py index 0de3a33237..6ae45b36bf 100644 --- a/ldm/simplet2i.py +++ b/ldm/simplet2i.py @@ -282,6 +282,11 @@ The vast majority of these arguments default to reasonable values. seed_everything(seed) iter_images = next(images_iterator) for image in iter_images: + try: + if gfpgan_strength > 0: + image = self._run_gfpgan(image, gfpgan_strength) + except Exception as e: + print(f"Error running GFPGAN - Your image was not enhanced.\n{e}") results.append([image, seed]) if image_callback is not None: image_callback(image,seed) @@ -305,7 +310,6 @@ The vast majority of these arguments default to reasonable values. batch_size, steps,cfg_scale,ddim_eta, skip_normalize, - gfpgan_strength, width,height): """ An infinite iterator of images from the prompt. @@ -325,7 +329,7 @@ The vast majority of these arguments default to reasonable values. unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=ddim_eta) - yield self._samples_to_images(samples, gfpgan_strength=gfpgan_strength) + yield self._samples_to_images(samples) @torch.no_grad() def _img2img(self, @@ -334,7 +338,6 @@ The vast majority of these arguments default to reasonable values. batch_size, steps,cfg_scale,ddim_eta, skip_normalize, - gfpgan_strength, init_img,strength): """ An infinite iterator of images from the prompt and the initial image @@ -365,7 +368,7 @@ The vast majority of these arguments default to reasonable values. # decode it samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc,) - yield self._samples_to_images(samples, gfpgan_strength) + yield self._samples_to_images(samples) # TODO: does this actually need to run every loop? does anything in it vary by random seed? def _get_uc_and_c(self, prompt, batch_size, skip_normalize): @@ -389,18 +392,13 @@ The vast majority of these arguments default to reasonable values. c = self.model.get_learned_conditioning(batch_size * [prompt]) return (uc, c) - def _samples_to_images(self, samples, gfpgan_strength=0): + def _samples_to_images(self, samples): x_samples = self.model.decode_first_stage(samples) x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) images = list() for x_sample in x_samples: x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') image = Image.fromarray(x_sample.astype(np.uint8)) - try: - if gfpgan_strength > 0: - image = self._run_gfpgan(image, gfpgan_strength) - except Exception as e: - print(f"Error running GFPGAN - Your image was not enhanced.\n{e}") images.append(image) return images