Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
switch to generators

parent 078859207d
commit 31b22e057d
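The commit replaces the per-call helpers with generators: instead of returning a get_images(prompts) closure that the caller invokes each iteration, _txt2img and _img2img now take the prompt up front and yield one batch of images every time the caller advances them with next(). A minimal, self-contained sketch of the two shapes, using toy names and a fake sampling step rather than the real InvokeAI API:

# Sketch only: illustrates the closure -> generator refactor with made-up names.

def fake_sample(prompt, batch_size):
    # Stand-in for the expensive diffusion sampling call.
    return [f'image {i} for "{prompt}"' for i in range(batch_size)]

# Before: a factory that builds and returns a closure.
def make_get_images(batch_size):
    def get_images(prompts):
        return fake_sample(prompts[0], batch_size)
    return get_images

# After: an infinite generator; each next() produces one batch.
def images_iterator(prompt, batch_size):
    while True:
        yield fake_sample(prompt, batch_size)

it = images_iterator('a red bicycle', batch_size=2)
for _ in range(3):          # three sampling iterations, one batch each
    print(next(it))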
@@ -257,26 +257,25 @@ The vast majority of these arguments default to reasonable values.
         try:
             if init_img:
                 assert os.path.exists(init_img),f'{init_img}: File not found'
-                get_images = self._img2img(
+                images_iterator = self._img2img(prompt,
                     precision_scope=scope,
                     batch_size=batch_size,
                     steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                     skip_normalize=skip_normalize,
                     init_img=init_img,strength=strength)
             else:
-                get_images = self._txt2img(
+                images_iterator = self._txt2img(prompt,
                     precision_scope=scope,
                     batch_size=batch_size,
                     steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                     skip_normalize=skip_normalize,
                     width=width,height=height)

-            data = [batch_size * [prompt]]
             with scope(self.device.type), self.model.ema_scope():
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        iter_images = get_images(prompts)
+                    for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
+                        iter_images = next(images_iterator)
                         for image in iter_images:
                             results.append([image, seed])
                             if image_callback is not None:
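Worth noting for the hunk above: calling self._img2img(...) or self._txt2img(...) no longer runs any sampling by itself; it only constructs a generator object. The body starts executing on the first next(images_iterator), which happens inside the with scope(...) / ema_scope() block, so the sampling work still runs while those contexts are active. A tiny sketch of that laziness, with illustrative names only:

# Sketch: a generator function's body does not run until next() is called.

def noisy_generator():
    print('setup runs now')          # executes on the first next(), not at creation
    while True:
        yield 'batch'

gen = noisy_generator()              # nothing printed yet
first = next(gen)                    # prints 'setup runs now', then yields 'batch'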
@@ -295,19 +294,20 @@ The vast majority of these arguments default to reasonable values.

     @torch.no_grad()
     def _txt2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  width,height):
         """
-        Generate an image from the prompt
+        An infinite iterator of images from the prompt.
         """

         sampler = self.sampler

-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
             samples, _ = sampler.sample(S=steps,
                                         conditioning=c,
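In the generator version of _txt2img above, everything before the while True: line (for example sampler = self.sampler) runs once, on the first next(); the loop body up to the yield runs again on every subsequent next(), and local variables persist between calls. A small sketch of that execution order, toy example only:

# Sketch: setup before the loop runs once; the loop body reruns on each next().

def counting_batches():
    setup = 'one-time setup'             # runs on the first next() only
    count = 0
    while True:
        count += 1                       # recomputed on every next()
        yield f'{setup}: batch {count}'

gen = counting_batches()
print(next(gen))   # 'one-time setup: batch 1'
print(next(gen))   # 'one-time setup: batch 2'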
@@ -317,18 +317,18 @@ The vast majority of these arguments default to reasonable values.
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)

     @torch.no_grad()
     def _img2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  init_img,strength):
         """
-        Generate an image from the prompt and the initial image
+        An infinite iterator of images from the prompt and the initial image
         """

         # PLMS sampler not supported yet, so ignore previous sampler
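The hunk above also shows why the inner get_images scaffolding disappears: once _txt2img contains a yield, it is a generator function, so calling it already returns a generator and there is nothing left to build and return by hand. The single yield inside the loop replaces both of the old return statements. A short illustration of the distinction, independent of the InvokeAI code:

# Sketch: the presence of yield is what makes a function a generator function.

def plain():
    return [1]

def generator():
    yield [1]

print(plain())            # [1]
print(generator())        # <generator object generator at 0x...>
print(next(generator()))  # [1]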
@@ -348,16 +348,15 @@ The vast majority of these arguments default to reasonable values.
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")

-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)

             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
             # decode it
             samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                      unconditional_conditioning=uc,)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)

     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
     def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
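The TODO above the _get_uc_and_c definition asks whether the conditioning has to be recomputed on every pass of the loop. If it genuinely does not vary with the random seed, the uc, c = self._get_uc_and_c(...) call could in principle move above the while True:, running once per generator instead of once per batch. A generic sketch of that hoisting pattern, using a toy computation and not a change made by this commit:

# Sketch: move work that does not change between yields above the loop.

def batches(prompt):
    conditioning = prompt.upper()        # stand-in for a loop-invariant computation
    while True:
        yield f'batch conditioned on {conditioning!r}'

gen = batches('a red bicycle')
print(next(gen))
print(next(gen))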
@@ -393,7 +392,6 @@ The vast majority of these arguments default to reasonable values.
             images.append(image)
         return images

-
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
         return self.seed