mirror of https://github.com/invoke-ai/InvokeAI
switch to generators
@@ -257,26 +257,25 @@ The vast majority of these arguments default to reasonable values.
         try:
             if init_img:
                 assert os.path.exists(init_img),f'{init_img}: File not found'
-                get_images = self._img2img(
+                images_iterator = self._img2img(prompt,
                     precision_scope=scope,
                     batch_size=batch_size,
                     steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                     skip_normalize=skip_normalize,
                     init_img=init_img,strength=strength)
             else:
-                get_images = self._txt2img(
+                images_iterator = self._txt2img(prompt,
                     precision_scope=scope,
                     batch_size=batch_size,
                     steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
                     skip_normalize=skip_normalize,
                     width=width,height=height)

-            data = [batch_size * [prompt]]
             with scope(self.device.type), self.model.ema_scope():
                 for n in trange(iterations, desc="Sampling"):
                     seed_everything(seed)
-                    for prompts in tqdm(data, desc="data", dynamic_ncols=True):
-                        iter_images = get_images(prompts)
+                    for batch_item in tqdm(range(batch_size), desc="data", dynamic_ncols=True):
+                        iter_images = next(images_iterator)
                         for image in iter_images:
                             results.append([image, seed])
                             if image_callback is not None:
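This first hunk changes the consumption side: instead of building a get_images closure and calling it with a list of prompts, the caller now holds a single images_iterator and pulls a batch of images on demand with next(). A minimal sketch of that pattern, with make_batches as a hypothetical stand-in for _txt2img/_img2img (not code from the commit):

    def make_batches(prompt, batch_size):
        """Infinite generator: each next() call produces one batch of results."""
        while True:
            # In the real code, conditioning and sampling would happen here.
            yield [f'{prompt} (image {i})' for i in range(batch_size)]

    images_iterator = make_batches('a fantasy landscape', batch_size=2)
    results = []
    for n in range(3):                     # plays the role of "iterations"
        batch = next(images_iterator)      # pull one batch on demand
        for image in batch:
            results.append([image, n])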
@@ -295,19 +294,20 @@ The vast majority of these arguments default to reasonable values.

     @torch.no_grad()
     def _txt2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  width,height):
         """
-        Generate an image from the prompt
+        An infinite iterator of images from the prompt.
         """

         sampler = self.sampler

-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
             shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
             samples, _ = sampler.sample(S=steps,
                                         conditioning=c,
@@ -317,18 +317,18 @@ The vast majority of these arguments default to reasonable values.
                                         unconditional_guidance_scale=cfg_scale,
                                         unconditional_conditioning=uc,
                                         eta=ddim_eta)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)

     @torch.no_grad()
     def _img2img(self,
+                 prompt,
                  precision_scope,
                  batch_size,
                  steps,cfg_scale,ddim_eta,
                  skip_normalize,
                  init_img,strength):
         """
-        Generate an image from the prompt and the initial image
+        An infinite iterator of images from the prompt and the initial image
         """

         # PLMS sampler not supported yet, so ignore previous sampler
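The pattern in these two hunks is the same for both methods: the body that used to define and return a get_images(prompts) closure becomes the method body itself, wrapped in while True: with return replaced by yield, so calling the method now produces a generator. A rough sketch of the before/after shapes, using illustrative names rather than the real sampling code:

    # Before: the method builds a closure and hands it back; the caller
    # must invoke it with a prompt list to do any work.
    def make_images_closure(prompt, batch_size):      # illustrative only
        def get_images(prompts):
            return [f'{p} -> image' for p in prompts[:batch_size]]
        return get_images

    # After: the method body loops forever and yields. Calling it returns a
    # generator object immediately; nothing runs until next() is called.
    def make_images_generator(prompt, batch_size):    # illustrative only
        while True:
            yield [f'{prompt} -> image' for _ in range(batch_size)]

    gen = make_images_generator('a cat', 2)
    first_batch = next(gen)      # the work happens here, not at call time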
@@ -348,16 +348,15 @@ The vast majority of these arguments default to reasonable values.
         t_enc = int(strength * steps)
         # print(f"target t_enc is {t_enc} steps")

-        def get_images(prompts):
-            uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
+        while True:
+            uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)

             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
             # decode it
             samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
                                      unconditional_conditioning=uc,)
-            return self._samples_to_images(samples)
-        return get_images
+            yield self._samples_to_images(samples)

     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
     def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
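One property the generator form gives _img2img is visible in this hunk: the setup above the loop (computing t_enc and preparing init_latent) sits outside while True:, so it runs only once per generator, at the first next(), while the conditioning, encode, and decode inside the loop repeat on every next(). A self-contained illustration of that setup-once / yield-many split (names are illustrative, not from the commit):

    import random

    def batches(seed, batch_size):
        rng = random.Random(seed)      # one-time setup, runs at the first next()
        while True:
            # per-batch work, repeated on every next()
            yield [rng.random() for _ in range(batch_size)]

    gen = batches(seed=42, batch_size=3)
    print(next(gen))   # first batch; setup ran once
    print(next(gen))   # second batch; setup not repeated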
@@ -393,7 +392,6 @@ The vast majority of these arguments default to reasonable values.
             images.append(image)
         return images

-
     def _new_seed(self):
         self.seed = random.randrange(0,np.iinfo(np.uint32).max)
         return self.seed