switch to generators

This commit is contained in:
Kevin Gibbons 2022-08-25 17:01:17 -07:00
parent 078859207d
commit 31b22e057d

View File

@ -257,26 +257,25 @@ The vast majority of these arguments default to reasonable values.
try:
if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found'
get_images = self._img2img(
images_iterator = self._img2img(prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,strength=strength)
else:
get_images = self._txt2img(
images_iterator = self._txt2img(prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,height=height)
data = [batch_size * [prompt]]
with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
iter_images = get_images(prompts)
for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
iter_images = next(images_iterator)
for image in iter_images:
results.append([image, seed])
if image_callback is not None:
@ -295,19 +294,20 @@ The vast majority of these arguments default to reasonable values.
@torch.no_grad()
def _txt2img(self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
width,height):
"""
Generate an image from the prompt
An infinite iterator of images from the prompt.
"""
sampler = self.sampler
def get_images(prompts):
uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples, _ = sampler.sample(S=steps,
conditioning=c,
@ -317,18 +317,18 @@ The vast majority of these arguments default to reasonable values.
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
return self._samples_to_images(samples)
return get_images
yield self._samples_to_images(samples)
@torch.no_grad()
def _img2img(self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
init_img,strength):
"""
Generate an image from the prompt and the initial image
An infinite iterator of images from the prompt and the initial image
"""
# PLMS sampler not supported yet, so ignore previous sampler
@ -348,16 +348,15 @@ The vast majority of these arguments default to reasonable values.
t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps")
def get_images(prompts):
uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
return self._samples_to_images(samples)
return get_images
yield self._samples_to_images(samples)
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
@ -393,7 +392,6 @@ The vast majority of these arguments default to reasonable values.
images.append(image)
return images
def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max)
return self.seed