Merge branch 'bakkot-more-refactor' into main

Lincoln Stein 2022-08-25 22:19:27 -04:00
commit 23fb4a72bb
2 changed files with 114 additions and 156 deletions


@@ -216,7 +216,7 @@ The vast majority of these arguments default to reasonable values.
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
variants // if >0, the 1st generated image will be passed back to img2img to generate the requested number of variants
callback // a function or method that will be called each time an image is generated
image_callback // a function or method that will be called each time an image is generated
To use the callback, define a function or method that receives two arguments, an Image object
and the seed. You can then do whatever you like with the image, including converting it to
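As a concrete illustration of the callback hook described above, here is a minimal usage sketch. The entry-point name (prompt2image), the module path (ldm.simplet2i), and the T2I constructor defaults are assumptions for illustration only; substitute the names from your checkout.

import os
from PIL import Image
from ldm.simplet2i import T2I   # module path assumed

def save_with_seed(image: Image.Image, seed: int) -> None:
    # Called once per generated image: write it out with its seed in the
    # filename so the exact image can be reproduced later.
    os.makedirs('outputs/img-samples', exist_ok=True)
    image.save(f'outputs/img-samples/{seed:010}.png')

t2i = T2I()  # assumes the default weights/config can be located
# 'prompt2image' stands in for the dispatching method shown in this diff.
results = t2i.prompt2image('a watercolor fox in the snow',
                           iterations=3,
                           image_callback=save_with_seed)
# results is the same list of [image, seed] pairs the callback already saw
for image, seed in results:
    print(seed, image.size)

Because the callback fires as soon as each image is decoded, a caller such as a UI can display or save results incrementally instead of waiting for the whole batch to finish.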
@@ -249,79 +249,67 @@ The vast majority of these arguments default to reasonable values.
height = h
width = w
data = [batch_size * [prompt]]
scope = autocast if self.precision=="autocast" else nullcontext
tic = time.time()
results = list()
try:
if init_img:
assert os.path.exists(init_img),f'{init_img}: File not found'
results = self._img2img(prompt,
data=data,precision_scope=scope,
batch_size=batch_size,iterations=iterations,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
images_iterator = self._img2img(prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
init_img=init_img,strength=strength,variants=variants,
callback=image_callback)
init_img=init_img,strength=strength)
else:
results = self._txt2img(prompt,
data=data,precision_scope=scope,
batch_size=batch_size,iterations=iterations,
steps=steps,seed=seed,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
images_iterator = self._txt2img(prompt,
precision_scope=scope,
batch_size=batch_size,
steps=steps,cfg_scale=cfg_scale,ddim_eta=ddim_eta,
skip_normalize=skip_normalize,
width=width,height=height,
callback=image_callback)
width=width,height=height)
with scope(self.device.type), self.model.ema_scope():
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
iter_images = next(images_iterator)
for image in iter_images:
results.append([image, seed])
if image_callback is not None:
image_callback(image,seed)
seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
print('Are you sure your system has an adequate NVIDIA GPU?')
toc = time.time()
print(f'{len(results)} images generated in',"%4.2fs"% (toc-tic))
return results
@torch.no_grad()
def _txt2img(self,prompt,
data,precision_scope,
batch_size,iterations,
steps,seed,cfg_scale,ddim_eta,
def _txt2img(self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
width,height,
callback): # the callback is called each time a new Image is generated
width,height):
"""
Generate an image from the prompt, writing iteration images into the outdir
The output is a list of lists in the format: [[image1,seed1], [image2,seed2],...]
An infinite iterator of images from the prompt.
"""
sampler = self.sampler
images = list()
image_count = 0
# Gawd. Too many levels of indent here. Need to refactor into smaller routines!
try:
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
# weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
if len(subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
# get total weight for normalizing
totalWeight = sum(weights)
# normalize each "sub prompt" and add it
for i in range(0,len(subprompts)):
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts)
while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
shape = [self.latent_channels, height // self.downsampling_factor, width // self.downsampling_factor]
samples_ddim, _ = sampler.sample(S=steps,
samples, _ = sampler.sample(S=steps,
conditioning=c,
batch_size=batch_size,
shape=shape,
@@ -329,36 +317,18 @@ The vast majority of these arguments default to reasonable values.
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=ddim_eta)
x_samples_ddim = self.model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print(str(e))
return images
yield self._samples_to_images(samples)
@torch.no_grad()
def _img2img(self,prompt,
data,precision_scope,
batch_size,iterations,
steps,seed,cfg_scale,ddim_eta,
def _img2img(self,
prompt,
precision_scope,
batch_size,
steps,cfg_scale,ddim_eta,
skip_normalize,
init_img,strength,variants,
callback):
init_img,strength):
"""
Generate an image from the prompt and the initial image, writing iteration images into the outdir
The output is a list of lists in the format: [[image,seed1], [image,seed2],...]
An infinite iterator of images from the prompt and the initial image
"""
# PLMS sampler not supported yet, so ignore previous sampler
@@ -377,22 +347,24 @@ The vast majority of these arguments default to reasonable values.
t_enc = int(strength * steps)
# print(f"target t_enc is {t_enc} steps")
images = list()
try:
with precision_scope(self.device.type), self.model.ema_scope():
all_samples = list()
for n in trange(iterations, desc="Sampling"):
seed_everything(seed)
for prompts in tqdm(data, desc="data", dynamic_ncols=True):
uc = None
if cfg_scale != 1.0:
while True:
uc, c = self._get_uc_and_c(prompt, batch_size, skip_normalize)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
yield self._samples_to_images(samples)
# TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, batch_size, skip_normalize):
uc = self.model.get_learned_conditioning(batch_size * [""])
if isinstance(prompts, tuple):
prompts = list(prompts)
# weighted sub-prompts
subprompts,weights = T2I._split_weighted_subprompts(prompts[0])
subprompts,weights = T2I._split_weighted_subprompts(prompt)
if len(subprompts) > 1:
# i dont know if this is correct.. but it works
c = torch.zeros_like(uc)
@@ -403,33 +375,19 @@ The vast majority of these arguments default to reasonable values.
weight = weights[i]
if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c,self.model.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c, self.model.get_learned_conditioning(batch_size * [subprompts[i]]), alpha=weight)
else: # just standard 1 prompt
c = self.model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(self.device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,)
c = self.model.get_learned_conditioning(batch_size * [prompt])
return (uc, c)
def _samples_to_images(self, samples):
x_samples = self.model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
images = list()
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
image = Image.fromarray(x_sample.astype(np.uint8))
images.append([image,seed])
if callback is not None:
callback(image,seed)
seed = self._new_seed()
except KeyboardInterrupt:
print('*interrupted*')
print('Partial results will be returned; if --grid was requested, nothing will be returned.')
except RuntimeError as e:
print("Oops! A runtime error has occurred. If this is unexpected, please copy-and-paste this stack trace and post it as an Issue to http://github.com/lstein/stable-diffusion")
traceback.print_exc()
images.append(image)
return images
def _new_seed(self):


@@ -252,7 +252,7 @@ def create_argv_parser():
'-o',
type=str,
default="outputs/img-samples",
help="directory in which to place generated images and a log of prompts and seeds")
help="directory in which to place generated images and a log of prompts and seeds (outputs/img-samples")
parser.add_argument('--embedding_path',
type=str,
help="Path to a pre-trained embedding manager checkpoint - can only be set on command line")