factor out loop

This commit is contained in:
Kevin Gibbons 2022-08-25 15:19:44 -07:00
parent a10baf5808
commit 078859207d

View File

@ -216,7 +216,7 @@ The vast majority of these arguments default to reasonable values.
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
callback // a function or method that will be called each time an image is generated image_callback // a function or method that will be called each time an image is generated
To use the callback, define a function of method that receives two arguments, an Image object To use the callback, define a function of method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to and the seed. You can then do whatever you like with the image, including converting it to
@ -249,34 +249,40 @@ The vast majority of these arguments default to reasonable values.
height = h height = h
width = w width = w
data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext scope = autocast if self.precision=="autocast" else nullcontext
tic = time.time() tic = time.time()
results = list() results = list()
def prompt_callback(image, seed):
results.append([image, seed])
if image_callback is not None:
image_callback(image, seed)
try: try:
if init_img: if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found' assert os.path.exists(init_img),f'{init_img}: File not found'
self._img2img(prompt, get_images = self._img2img(
data=data,precision_scope=scope, precision_scope=scope,
batch_size=batch_size,iterations=iterations, batch_size=batch_size,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta, steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
init_img=init_img,strength=strength,variants=variants, init_img=init_img,strength=strength)
callback=prompt_callback)
else: else:
self._txt2img(prompt, get_images = self._txt2img(
data=data,precision_scope=scope, precision_scope=scope,
batch_size=batch_size,iterations=iterations, batch_size=batch_size,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta, steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
width=width,height=height, width=width,height=height)
callback=prompt_callback)
data = [batch_size * [prompt]]
with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
iter_images = get_images(prompts)
for image in iter_images:
results.append([image, seed])
if image_callback is not None:
image_callback(image,seed)
seed = self._new_seed()
except KeyboardInterrupt: except KeyboardInterrupt:
print('*interrupted*') print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.') print('Partial results will be returned; if --grid was requested, nothing will be returned.')
@ -288,84 +294,41 @@ The vast majority of these arguments default to reasonable values.
return results return results
@torch.no_grad() @torch.no_grad()
def _txt2img(self,prompt, def _txt2img(self,
data,precision_scope, precision_scope,
batch_size,iterations, batch_size,
steps,seed,cfg_scale,ddim_eta, steps,cfg_scale,ddim_eta,
skip_normalize, skip_normalize,
width,height, width,height):
callback): # the callback is called each time a new Image is generated
""" """
Generate an image from the prompt, writing iteration images into the outdir Generate an image from the prompt
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
""" """
sampler = self.sampler sampler = self.sampler
images = list()
image_count = 0
# Gawd. Too many levels of indent here. Need to refactor into smaller routines! def get_images(prompts):
with precision_scope(self.device.type), self.model.ema_scope(): uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
all_samples = list() shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
for n in trange(iterations, desc="Sampling"): samples, _ = sampler.sample(S=steps,
seed_everything(seed) conditioning=c,
for prompts in tqdm(data, desc="data", dynamic_ncols=True): batch_size=batch_size,
uc = None shape=shape,
if cfg_scale != 1.0: verbose=False,
uc = self.model.get_learned_conditioning(batch_size * [""]) unconditional_guidance_scale=cfg_scale,
if isinstance(prompts, tuple): unconditional_conditioning=uc,
prompts = list(prompts) eta=ddim_eta)
return self._samples_to_images(samples)
# weighted sub-prompts return get_images
subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0,len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
return images
@torch.no_grad() @torch.no_grad()
def _img2img(self,prompt, def _img2img(self,
data,precision_scope, precision_scope,
batch_size,iterations, batch_size,
steps,seed,cfg_scale,ddim_eta, steps,cfg_scale,ddim_eta,
skip_normalize, skip_normalize,
init_img,strength,variants, init_img,strength):
callback):
""" """
Generate an image from the prompt and the initial image, writing iteration images into the outdir Generate an image from the prompt and the initial image
The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
""" """
# PLMS sampler not supported yet, so ignore previous sampler # PLMS sampler not supported yet, so ignore previous sampler
@ -384,54 +347,53 @@ The vast majority of these arguments default to reasonable values.
t_enc = int(strength * steps) t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps") # print(f"target t_enc is {t_enc} steps")
def get_images(prompts):
uc, c = self._get_uc_and_c(prompts, batch_size, skip_normalize)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
return self._samples_to_images(samples)
return get_images
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompts, batch_size, skip_normalize):
if isinstance(prompts, tuple):
prompts = list(prompts)
uc = self.model.get_learned_conditioning(batch_size * [""])
# weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0,len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts)
return (uc, c)
def _samples_to_images(self, samples):
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list() images = list()
for x_sample in x_samples:
with precision_scope(self.device.type), self.model.ema_scope(): x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
all_samples = list() image = Image.fromarray(x_sample.astype(np.uint8))
for n in trange(iterations, desc="Sampling"): images.append(image)
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
# weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0,len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
return images return images
def _new_seed(self): def _new_seed(self):
self.seed = random.randrange(0,np.iinfo(np.uint32).max) self.seed = random.randrange(0,np.iinfo(np.uint32).max)
return self.seed return self.seed