Merge branch 'seed-fuzz' of github.com:bakkot/stable-diffusion into bakkot-seed-fuzz

This commit is contained in:
Lincoln Stein 2022-09-02 16:17:51 -04:00
commit 2d65b03f05
4 changed files with 191 additions and 47 deletions

View File

@ -69,6 +69,11 @@ class PromptFormatter:
switches.append(f'-G{opt.gfpgan_strength}') switches.append(f'-G{opt.gfpgan_strength}')
if opt.upscale: if opt.upscale:
switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}') switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
if opt.variation_amount > 0:
switches.append(f'-v {opt.variation_amount}')
if opt.with_variations:
formatted_variations = ';'.join(f'{seed},{weight}' for seed, weight in opt.with_variations)
switches.append(f'-V {formatted_variations}')
if t2i.full_precision: if t2i.full_precision:
switches.append('-F') switches.append('-F')
return ' '.join(switches) return ' '.join(switches)

View File

@ -66,8 +66,8 @@ class KSampler(object):
img_callback(k_callback_values['x'], k_callback_values['i']) img_callback(k_callback_values['x'], k_callback_values['i'])
sigmas = self.model.get_sigmas(S) sigmas = self.model.get_sigmas(S)
if x_T: if x_T is not None:
x = x_T x = x_T * sigmas[0]
else: else:
x = ( x = (
torch.randn([batch_size, *shape], device=self.device) torch.randn([batch_size, *shape], device=self.device)

View File

@ -226,6 +226,8 @@ class T2I:
upscale = None, upscale = None,
sampler_name = None, sampler_name = None,
log_tokenization= False, log_tokenization= False,
with_variations = None,
variation_amount = 0.0,
**args, **args,
): # eat up additional cruft ): # eat up additional cruft
""" """
@ -244,6 +246,8 @@ class T2I:
ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image) ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
step_callback // a function or method that will be called each step step_callback // a function or method that will be called each step
image_callback // a function or method that will be called each time an image is generated image_callback // a function or method that will be called each time an image is generated
with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
To use the step callback, define a function that receives two arguments: To use the step callback, define a function that receives two arguments:
- Image GPU data - Image GPU data
@ -270,6 +274,7 @@ class T2I:
iterations = iterations or self.iterations iterations = iterations or self.iterations
strength = strength or self.strength strength = strength or self.strength
self.log_tokenization = log_tokenization self.log_tokenization = log_tokenization
with_variations = [] if with_variations is None else with_variations
model = ( model = (
self.load_model() self.load_model()
@ -278,6 +283,18 @@ class T2I:
assert ( assert (
0.0 <= strength <= 1.0 0.0 <= strength <= 1.0
), 'can only work with strength in [0.0, 1.0]' ), 'can only work with strength in [0.0, 1.0]'
assert (
0.0 <= variation_amount <= 1.0
), '-v --variation_amount must be in [0.0, 1.0]'
if len(with_variations) > 0:
assert seed is not None,\
'seed must be specified when using with_variations'
if variation_amount == 0.0:
assert iterations == 1,\
'when using --with_variations, multiple iterations are only possible when using --variation_amount'
assert all(0 <= weight <= 1 for _, weight in with_variations),\
f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
width, height, _ = self._resolution_check(width, height, log=True) width, height, _ = self._resolution_check(width, height, log=True)
@ -301,24 +318,25 @@ class T2I:
try: try:
if init_img: if init_img:
assert os.path.exists(init_img), f'{init_img}: File not found' assert os.path.exists(init_img), f'{init_img}: File not found'
images_iterator = self._img2img( init_image = self._load_img(init_img, width, height, fit).to(self.device)
with scope(device.type):
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
make_image = self._img2img(
prompt, prompt,
precision_scope=scope,
steps=steps, steps=steps,
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
ddim_eta=ddim_eta, ddim_eta=ddim_eta,
skip_normalize=skip_normalize, skip_normalize=skip_normalize,
init_img=init_img, init_latent=init_latent,
width=width,
height=height,
fit=fit,
strength=strength, strength=strength,
callback=step_callback, callback=step_callback,
) )
else: else:
images_iterator = self._txt2img( make_image = self._txt2img(
prompt, prompt,
precision_scope=scope,
steps=steps, steps=steps,
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
ddim_eta=ddim_eta, ddim_eta=ddim_eta,
@ -328,11 +346,45 @@ class T2I:
callback=step_callback, callback=step_callback,
) )
def get_noise():
    """
    Draw one fresh batch of starting noise from the current torch RNG state.

    With an init image the noise mirrors the already-encoded init latent;
    otherwise it is a single-sample latent whose spatial size is the
    requested height/width divided by the model's downsampling factor.
    """
    if not init_img:
        factor = self.downsampling_factor
        latent_shape = [
            1,
            self.latent_channels,
            height // factor,
            width // factor,
        ]
        return torch.randn(latent_shape, device=self.device)
    # img2img: match the init latent exactly so the two can be combined
    return torch.randn_like(init_latent, device=self.device)
initial_noise = None
if variation_amount > 0 or len(with_variations) > 0:
# use fixed initial noise plus random noise per iteration
seed_everything(seed)
initial_noise = get_noise()
for v_seed, v_weight in with_variations:
seed = v_seed
seed_everything(seed)
next_noise = get_noise()
initial_noise = self.slerp(v_weight, initial_noise, next_noise)
if variation_amount > 0:
random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
seed = random.randrange(0,np.iinfo(np.uint32).max)
device_type = choose_autocast_device(self.device) device_type = choose_autocast_device(self.device)
with scope(device_type), self.model.ema_scope(): with scope(device_type), self.model.ema_scope():
for n in trange(iterations, desc='Generating'): for n in trange(iterations, desc='Generating'):
seed_everything(seed) x_T = None
image = next(images_iterator) if variation_amount > 0:
seed_everything(seed)
target_noise = get_noise()
x_T = self.slerp(variation_amount, initial_noise, target_noise)
elif initial_noise is not None:
# i.e. we specified particular variations
x_T = initial_noise
else:
seed_everything(seed)
# make_image will do the equivalent of get_noise itself
image = make_image(x_T)
results.append([image, seed]) results.append([image, seed])
if image_callback is not None: if image_callback is not None:
image_callback(image, seed) image_callback(image, seed)
@ -406,7 +458,6 @@ class T2I:
def _txt2img( def _txt2img(
self, self,
prompt, prompt,
precision_scope,
steps, steps,
cfg_scale, cfg_scale,
ddim_eta, ddim_eta,
@ -416,12 +467,13 @@ class T2I:
callback, callback,
): ):
""" """
An infinite iterator of images from the prompt. Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
""" """
sampler = self.sampler sampler = self.sampler
while True: def make_image(x_T):
uc, c = self._get_uc_and_c(prompt, skip_normalize) uc, c = self._get_uc_and_c(prompt, skip_normalize)
shape = [ shape = [
self.latent_channels, self.latent_channels,
@ -431,6 +483,7 @@ class T2I:
samples, _ = sampler.sample( samples, _ = sampler.sample(
batch_size=1, batch_size=1,
S=steps, S=steps,
x_T=x_T,
conditioning=c, conditioning=c,
shape=shape, shape=shape,
verbose=False, verbose=False,
@ -439,26 +492,24 @@ class T2I:
eta=ddim_eta, eta=ddim_eta,
img_callback=callback img_callback=callback
) )
yield self._sample_to_image(samples) return self._sample_to_image(samples)
return make_image
@torch.no_grad() @torch.no_grad()
def _img2img( def _img2img(
self, self,
prompt, prompt,
precision_scope,
steps, steps,
cfg_scale, cfg_scale,
ddim_eta, ddim_eta,
skip_normalize, skip_normalize,
init_img, init_latent,
width,
height,
fit,
strength, strength,
callback, # Currently not implemented for img2img callback, # Currently not implemented for img2img
): ):
""" """
An infinite iterator of images from the prompt and the initial image Returns a function returning an image derived from the prompt and the initial image
Return value depends on the seed at the time you call it
""" """
# PLMS sampler not supported yet, so ignore previous sampler # PLMS sampler not supported yet, so ignore previous sampler
@ -470,24 +521,20 @@ class T2I:
else: else:
sampler = self.sampler sampler = self.sampler
init_image = self._load_img(init_img, width, height,fit).to(self.device)
with precision_scope(self.device.type):
init_latent = self.model.get_first_stage_encoding(
self.model.encode_first_stage(init_image)
) # move to latent space
sampler.make_schedule( sampler.make_schedule(
ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
) )
t_enc = int(strength * steps) t_enc = int(strength * steps)
while True: def make_image(x_T):
uc, c = self._get_uc_and_c(prompt, skip_normalize) uc, c = self._get_uc_and_c(prompt, skip_normalize)
# encode (scaled latent) # encode (scaled latent)
z_enc = sampler.stochastic_encode( z_enc = sampler.stochastic_encode(
init_latent, torch.tensor([t_enc]).to(self.device) init_latent,
torch.tensor([t_enc]).to(self.device),
noise=x_T
) )
# decode it # decode it
samples = sampler.decode( samples = sampler.decode(
@ -498,7 +545,8 @@ class T2I:
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
) )
yield self._sample_to_image(samples) return self._sample_to_image(samples)
return make_image
# TODO: does this actually need to run every loop? does anything in it vary by random seed? # TODO: does this actually need to run every loop? does anything in it vary by random seed?
def _get_uc_and_c(self, prompt, skip_normalize): def _get_uc_and_c(self, prompt, skip_normalize):
@ -513,8 +561,7 @@ class T2I:
# i dont know if this is correct.. but it works # i dont know if this is correct.. but it works
c = torch.zeros_like(uc) c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it # normalize each "sub prompt" and add it
for i in range(0, len(weighted_subprompts)): for subprompt, weight in weighted_subprompts:
subprompt, weight = weighted_subprompts[i]
self._log_tokenization(subprompt) self._log_tokenization(subprompt)
c = torch.add( c = torch.add(
c, c,
@ -619,7 +666,7 @@ class T2I:
print( print(
f'>> loaded input image of size {image.width}x{image.height} from {path}' f'>> loaded input image of size {image.width}x{image.height} from {path}'
) )
# The logic here is: # The logic here is:
# 1. If "fit" is true, then the image will be fit into the bounding box defined # 1. If "fit" is true, then the image will be fit into the bounding box defined
# by width and height. It will do this in a way that preserves the init image's # by width and height. It will do this in a way that preserves the init image's
@ -644,7 +691,7 @@ class T2I:
if resize_needed: if resize_needed:
return InitImageResizer(image).resize(x,y) return InitImageResizer(image).resize(x,y)
return image return image
def _fit_image(self,image,max_dimensions): def _fit_image(self,image,max_dimensions):
w,h = max_dimensions w,h = max_dimensions
@ -677,10 +724,10 @@ class T2I:
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt' ) # end 'prompt'
(?: # non-capture group (?: # non-capture group
:+ # match one or more ':' characters :+ # match one or more ':' characters
(?P<weight> # capture group for 'weight' (?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number -?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional )? # end weight capture group, make optional
\s* # strip spaces after weight \s* # strip spaces after weight
| # OR | # OR
$ # else, if no ':' then match end of line $ # else, if no ':' then match end of line
@ -741,3 +788,41 @@ class T2I:
print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.") print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
return width, height, resize_needed return width, height, resize_needed
def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
    '''
    Spherical linear interpolation between two noise vectors.

    The angle between the inputs is taken from the dot product of the
    whole arrays divided by the product of their 2-norms. When the inputs
    are (nearly) collinear, or when either of them has zero length and
    therefore no direction, plain linear interpolation is used instead.

    Args:
        t (float/np.ndarray): Float value between 0.0 and 1.0
        v0 (np.ndarray or torch.Tensor): Starting vector
        v1 (np.ndarray or torch.Tensor): Final vector
        DOT_THRESHOLD (float): Threshold for considering the two vectors as
                               collinear. Not recommended to alter this.
    Returns:
        v2: Interpolation vector between v0 and v1. An np.ndarray when both
            inputs were np.ndarray; otherwise a torch tensor moved to
            self.device.
    '''
    # Anything that is not already a numpy array is treated as a torch
    # tensor and converted, so the maths below runs on numpy only.
    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        v0 = v0.detach().cpu().numpy()
    if not isinstance(v1, np.ndarray):
        inputs_are_torch = True
        v1 = v1.detach().cpu().numpy()

    norm_product = np.linalg.norm(v0) * np.linalg.norm(v1)
    if norm_product == 0:
        # A zero-length vector has no angle to interpolate over; dividing
        # by the zero norm would turn the whole result into NaN.
        use_lerp = True
    else:
        dot = np.sum(v0 * v1 / norm_product)
        use_lerp = np.abs(dot) > DOT_THRESHOLD

    if use_lerp:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(self.device)

    return v2

View File

@ -181,9 +181,32 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print(f'No previous seed at position {opt.seed} found') print(f'No previous seed at position {opt.seed} found')
opt.seed = None opt.seed = None
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
do_grid = opt.grid or t2i.grid do_grid = opt.grid or t2i.grid
individual_images = not do_grid
if opt.with_variations is not None:
# shotgun parsing, woo
parts = []
broken = False # python doesn't have labeled loops...
for part in opt.with_variations.split(';'):
seed_and_weight = part.split(',')
if len(seed_and_weight) != 2:
print(f'could not parse with_variation part "{part}"')
broken = True
break
try:
seed = int(seed_and_weight[0])
weight = float(seed_and_weight[1])
except ValueError:
print(f'could not parse with_variation part "{part}"')
broken = True
break
parts.append([seed, weight])
if broken:
continue
if len(parts) > 0:
opt.with_variations = parts
else:
opt.with_variations = None
if opt.outdir: if opt.outdir:
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
@ -211,7 +234,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
file_writer = PngWriter(current_outdir) file_writer = PngWriter(current_outdir)
prefix = file_writer.unique_prefix() prefix = file_writer.unique_prefix()
seeds = set() seeds = set()
results = [] results = [] # list of filename, prompt pairs
grid_images = dict() # seed -> Image, only used if `do_grid` grid_images = dict() # seed -> Image, only used if `do_grid`
def image_writer(image, seed, upscaled=False): def image_writer(image, seed, upscaled=False):
if do_grid: if do_grid:
@ -221,10 +244,26 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
filename = f'{prefix}.{seed}.postprocessed.png' filename = f'{prefix}.{seed}.postprocessed.png'
else: else:
filename = f'{prefix}.{seed}.png' filename = f'{prefix}.{seed}.png'
path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename) if opt.variation_amount > 0:
iter_opt = argparse.Namespace(**vars(opt)) # copy
this_variation = [[seed, opt.variation_amount]]
if opt.with_variations is None:
iter_opt.with_variations = this_variation
else:
iter_opt.with_variations = opt.with_variations + this_variation
iter_opt.variation_amount = 0
normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
elif opt.with_variations is not None:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed
else:
normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{seed}'
path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename)
if (not upscaled) or opt.save_original: if (not upscaled) or opt.save_original:
# only append to results if we didn't overwrite an earlier output # only append to results if we didn't overwrite an earlier output
results.append([path, seed]) results.append([path, metadata_prompt])
seeds.add(seed) seeds.add(seed)
@ -235,11 +274,12 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
first_seed = next(iter(seeds)) first_seed = next(iter(seeds))
filename = f'{prefix}.{first_seed}.png' filename = f'{prefix}.{first_seed}.png'
# TODO better metadata for grid images # TODO better metadata for grid images
metadata_prompt = f'{normalized_prompt} -S{first_seed}' normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
path = file_writer.save_image_and_prompt_to_png( path = file_writer.save_image_and_prompt_to_png(
grid_img, metadata_prompt, filename grid_img, metadata_prompt, filename
) )
results = [[path, seeds]] results = [[path, metadata_prompt]]
last_seeds = list(seeds) last_seeds = list(seeds)
@ -253,7 +293,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
print('Outputs:') print('Outputs:')
log_path = os.path.join(current_outdir, 'dream_log.txt') log_path = os.path.join(current_outdir, 'dream_log.txt')
write_log_message(normalized_prompt, results, log_path) write_log_message(results, log_path)
print('goodbye!') print('goodbye!')
@ -291,9 +331,9 @@ def dream_server_loop(t2i):
dream_server.server_close() dream_server.server_close()
def write_log_message(prompt, results, log_path): def write_log_message(results, log_path):
"""logs the name of the output image, prompt, and prompt args to the terminal and log file""" """logs the name of the output image, prompt, and prompt args to the terminal and log file"""
log_lines = [f'{r[0]}: {prompt} -S{r[1]}\n' for r in results] log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
print(*log_lines, sep='') print(*log_lines, sep='')
with open(log_path, 'a', encoding='utf-8') as file: with open(log_path, 'a', encoding='utf-8') as file:
@ -546,6 +586,20 @@ def create_cmd_parser():
action='store_true', action='store_true',
help='shows how the prompt is split into tokens' help='shows how the prompt is split into tokens'
) )
# Variation controls. -v takes a float amount; -V takes the raw
# 'seed,weight;seed,weight' string, which is kept as text here
# (type=str) and split into pairs by the caller.
parser.add_argument(
    '-v',
    '--variation_amount',
    default=0.0,
    type=float,
    help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
)
parser.add_argument(
    '-V',
    '--with_variations',
    default=None,
    type=str,
    # closing backtick added so the quoted format example is balanced
    help='list of variations to apply, in the format `seed,weight;seed,weight;...`'
)
return parser return parser