fix batch_size

commit 797de3257c
parent 31b22e057d
Author: Kevin Gibbons
Date:   2022-08-25 17:16:07 -07:00


@@ -274,7 +274,6 @@ The vast majority of these arguments default to reasonable values.
         with scope(self.device.type), self.model.ema_scope():
             for n in trange(iterations, desc="Sampling"):
                 seed_everything(seed)
-                for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
-                    iter_images = next(images_iterator)
-                    for image in iter_images:
-                        results.append([image, seed])
+                iter_images = next(images_iterator)
+                for image in iter_images:
+                    results.append([image, seed])
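The removed inner loop is what the commit title refers to: each call to next(images_iterator) already yields a full batch of batch_size images, so wrapping it in a further range(batch_size) loop drained the iterator batch_size times per sampling iteration. A minimal sketch of the fixed loop shape, using a stand-in generator (names here are illustrative, not the repo's actual objects):

```python
# Stand-in for the sampler's images_iterator: each next() yields one
# full batch of batch_size images.
def make_images_iterator(batch_size, iterations):
    for _ in range(iterations):
        yield [f"img{i}" for i in range(batch_size)]

batch_size, iterations, seed = 3, 2, 42
results = []
it = make_images_iterator(batch_size, iterations)
for n in range(iterations):
    # One next() per iteration is enough; the old extra range(batch_size)
    # loop here consumed batch_size batches per iteration.
    for image in next(it):
        results.append([image, seed])

assert len(results) == iterations * batch_size  # 6 images, as intended
```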
@@ -359,14 +358,12 @@ The vast majority of these arguments default to reasonable values.
                 yield self._samples_to_images(samples)
 
     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
-    def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
-        if isinstance(prompts, tuple):
-            prompts = list(prompts)
+    def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
 
         uc = self.model.get_learned_conditioning(batch_size * [""])
 
         # weighted sub-prompts
-        subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
+        subprompts,weights = T2I._split_weighted_subprompts(prompt)
         if len(subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
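_get_uc_and_c now takes a single prompt string rather than a prompts sequence, and builds the batch dimension by list replication: batch_size * [""] for the unconditional embedding and, below, batch_size * [prompt] for the conditional one. A sketch of why the replication matters, with a stand-in encoder (the real model.get_learned_conditioning is the latent-diffusion text encoder; shapes here are illustrative):

```python
import torch

# Stand-in for model.get_learned_conditioning: maps a list of n strings
# to an (n, seq_len, embed_dim) conditioning tensor.
def get_learned_conditioning(texts, seq_len=77, embed_dim=768):
    return torch.randn(len(texts), seq_len, embed_dim)

batch_size, prompt = 4, "a painting of a fox"

uc = get_learned_conditioning(batch_size * [""])     # ['', '', '', '']
c = get_learned_conditioning(batch_size * [prompt])  # four copies of prompt

# The sampler pairs uc and c row for row, so their batch dims must match;
# encoding the bare prompt (or only one copy) would leave c with batch
# dimension 1 against uc's batch_size.
assert uc.shape == c.shape == (batch_size, 77, 768)
```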
@@ -377,9 +374,9 @@ The vast majority of these arguments default to reasonable values.
                 weight = weights[i]
                 if not skip_normalize:
                     weight = weight / totalWeight
-                c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
         else: # just standard 1 prompt
-            c = self.model.get_learned_conditioning(prompts)
+            c = self.model.get_learned_conditioning(batch_size * [prompt])
         return (uc, c)
 
     def _samples_to_images(self, samples):
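In the weighted sub-prompt branch, torch.add(c, x, alpha=weight) computes c + weight * x, so the loop accumulates a weighted sum of per-sub-prompt conditionings, normalized by totalWeight unless skip_normalize is set. A self-contained sketch of that accumulation (tensor shapes are illustrative stand-ins for the encoder's output):

```python
import torch

# Per-sub-prompt conditionings and their parsed weights.
conds = [torch.randn(2, 77, 768) for _ in range(3)]
weights = [1.0, 0.5, 0.25]
total_weight = sum(weights)

c = torch.zeros_like(conds[0])
for cond, weight in zip(conds, weights):
    weight = weight / total_weight        # skipped when skip_normalize is set
    c = torch.add(c, cond, alpha=weight)  # c += weight * cond

# The loop is equivalent to a normalized weighted sum.
expected = sum(w / total_weight * x for x, w in zip(conds, weights))
assert torch.allclose(c, expected)
```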