Refactor so that behavior is consolidated at top level

This commit is contained in:
Sean McLellan 2022-08-26 00:39:57 -04:00
parent 470a62dbbe
commit fcdd95b652

View File

@ -282,6 +282,11 @@ The vast majority of these arguments default to reasonable values.
seed_everything(seed)
iter_images = next(images_iterator)
for image in iter_images:
try:
if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength)
except Exception as e:
print(f"Error running GFPGAN - Your image was not enhanced.\n{e}")
results.append([image, seed])
if image_callback is not None:
image_callback(image,seed)
@ -305,7 +310,6 @@ The vast majority of these arguments default to reasonable values.
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
gfpgan_strength,
width,height):
"""
An infinite iterator of images from the prompt.
@ -325,7 +329,7 @@ The vast majority of these arguments default to reasonable values.
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
yield self._samples_to_images(samples, gfpgan_strength=gfpgan_strength)
yield self._samples_to_images(samples)
@torch.no_grad()
def _img2img(self,
@ -334,7 +338,6 @@ The vast majority of these arguments default to reasonable values.
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
gfpgan_strength,
init_img,strength):
"""
An infinite iterator of images from the prompt and the initial image
@ -365,7 +368,7 @@ The vast majority of these arguments default to reasonable values.
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
yield self._samples_to_images(samples, gfpgan_strength)
yield self._samples_to_images(samples)
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
@ -389,18 +392,13 @@ The vast majority of these arguments default to reasonable values.
c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c)
def _samples_to_images(self, samples, gfpgan_strength=0):
def _samples_to_images(self, samples):
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list()
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
try:
if gfpgan_strength > 0:
image = self._run_gfpgan(image, gfpgan_strength)
except Exception as e:
print(f"Error running GFPGAN - Your image was not enhanced.\n{e}")
images.append(image)
return images