diff --git a/README.md b/README.md
index 38cb46f681..6adce59e9d 100644
--- a/README.md
+++ b/README.md
@@ -23,6 +23,7 @@ text-to-image generator. This fork supports:
 
 3. A basic Web interface that allows you to run a local web server for
 generating images in your browser.
+
 4. A notebook for running the code on Google Colab.
 
 5. Upscaling and face fixing using the optional ESRGAN and GFPGAN
@@ -30,7 +31,11 @@ text-to-image generator. This fork supports:
 
 6. Weighted subprompts for prompt tuning.
 
-7. Textual inversion for customization of the prompt language and images.
+7. [Image variations](VARIATIONS.md) which allow you to systematically
+generate variations of an image you like and merge two or more
+images to combine the best features of each.
+
+8. Textual inversion for customization of the prompt language and images.
 
 8. ...and more!
 
diff --git a/VARIATIONS.md b/VARIATIONS.md
new file mode 100644
index 0000000000..c0699909a1
--- /dev/null
+++ b/VARIATIONS.md
@@ -0,0 +1,113 @@
+# Cheat Sheet for Generating Variations
+
+Release 1.13 of SD-Dream adds support for image variations. There are two things that you can do:
+
+1. Generate a series of systematic variations of an image, given a
+prompt. The amount of variation from one image to the next can be
+controlled.
+
+2. Given two or more variations that you like, you can combine them in
+a weighted fashion.
+
+This cheat sheet provides a quick guide for how this works in
+practice, using variations to create the desired image of Xena,
+Warrior Princess.
+
+## Step 1 -- find a base image that you like
+
+The prompt we will use throughout is "lucy lawless as xena, warrior
+princess, character portrait, high resolution." This will be indicated
+as "prompt" in the examples below.
+
+First we let SD create a series of images in the usual way, in this case
+requesting six iterations:
+
+~~~
+dream> lucy lawless as xena, warrior princess, character portrait, high resolution -n6
+...
+Outputs:
+./outputs/Xena/000001.1579445059.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S1579445059
+./outputs/Xena/000001.1880768722.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S1880768722
+./outputs/Xena/000001.332057179.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S332057179
+./outputs/Xena/000001.2224800325.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S2224800325
+./outputs/Xena/000001.465250761.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S465250761
+./outputs/Xena/000001.3357757885.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -S3357757885
+~~~
+
+The one with seed 3357757885 looks nice:
+
+
+
+Let's try to generate some variations. Using the same seed, we pass
+the argument -v0.2 (or --variation_amount), which generates a series of
+variations each differing by a variation amount of 0.2. This number
+ranges from 0 to 1.0, with higher numbers being larger amounts of
+variation.
+
+~~~
+dream> "prompt" -n6 -S3357757885 -v0.2
+...
+Outputs:
+./outputs/Xena/000002.784039624.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 784039624,0.2 -S3357757885
+./outputs/Xena/000002.3647897225.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.2 -S3357757885
+./outputs/Xena/000002.917731034.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 917731034,0.2 -S3357757885
+./outputs/Xena/000002.4116285959.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 4116285959,0.2 -S3357757885
+./outputs/Xena/000002.1614299449.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 1614299449,0.2 -S3357757885
+./outputs/Xena/000002.1335553075.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 1335553075,0.2 -S3357757885
+~~~
+
+Note that the output for each image has a -V option giving the
+"variant subseed" for that image, consisting of a seed followed by the
+variation amount used to generate it.
+
+This gives us a series of closely-related variations, including the
+two shown here.
+
+
+
+
+
+I like the expression on Xena's face in the first one (subseed
+3647897225), and the armor on her shoulder in the second one (subseed
+1614299449). Can we combine them to get the best of both worlds?
+
+We combine the two variations using -V (--with_variations). Again, we
+must provide the seed for the originally-chosen image in order for
+this to work.
+
+~~~
+dream> "prompt" -S3357757885 -V3647897225,0.1;1614299449,0.1
+Outputs:
+./outputs/Xena/000003.1614299449.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1 -S3357757885
+~~~
+
+Here we are providing equal weights (0.1 and 0.1) for both the
+subseeds. The resulting image is close, but not exactly what I
+wanted:
+
+
+
+We could either try combining the images with different weights, or we
+can generate more variations around the almost-but-not-quite image. We
+do the latter, using both the -V (combining) and -v (variation
+strength) options. Note that we use -n6 to generate 6 variations:
+
+~~~
+dream> "prompt" -S3357757885 -V3647897225,0.1;1614299449,0.1 -v0.05 -n6
+Outputs:
+./outputs/Xena/000004.3279757577.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;3279757577,0.05 -S3357757885
+./outputs/Xena/000004.2853129515.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;2853129515,0.05 -S3357757885
+./outputs/Xena/000004.3747154981.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;3747154981,0.05 -S3357757885
+./outputs/Xena/000004.2664260391.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;2664260391,0.05 -S3357757885
+./outputs/Xena/000004.1642517170.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;1642517170,0.05 -S3357757885
+./outputs/Xena/000004.2183375608.png: "prompt" -s50 -W512 -H512 -C7.5 -Ak_lms -V 3647897225,0.1;1614299449,0.1;2183375608,0.05 -S3357757885
+~~~
+
+This produces six images, all slight variations on the combination of
+the chosen two images. Here's the one I like best:
+
+
+
+As you can see, this is a very powerful tool, which when combined with
+subprompt weighting, gives you great control over the content and
+quality of your generated images.
\ No newline at end of file
diff --git a/ldm/dream/pngwriter.py b/ldm/dream/pngwriter.py
index 8dfc09236a..1e86eb8fbc 100644
--- a/ldm/dream/pngwriter.py
+++ b/ldm/dream/pngwriter.py
@@ -69,6 +69,11 @@ class PromptFormatter:
             switches.append(f'-G{opt.gfpgan_strength}')
         if opt.upscale:
             switches.append(f'-U {" ".join([str(u) for u in opt.upscale])}')
+        if opt.variation_amount > 0:
+            switches.append(f'-v {opt.variation_amount}')
+        if opt.with_variations:
+            formatted_variations = ';'.join(f'{seed},{weight}' for seed, weight in opt.with_variations)
+            switches.append(f'-V {formatted_variations}')
         if t2i.full_precision:
             switches.append('-F')
         return ' '.join(switches)
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index 7e3f40883d..0f6814940e 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -66,8 +66,8 @@ class KSampler(object):
                 img_callback(k_callback_values['x'], k_callback_values['i'])
 
         sigmas = self.model.get_sigmas(S)
-        if x_T:
-            x = x_T
+        if x_T is not None:
+            x = x_T * sigmas[0]
         else:
             x = (
                 torch.randn([batch_size, *shape], device=self.device)
diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index 10720a7483..bfe2c99cc4 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -226,6 +226,8 @@ class T2I:
             upscale = None,
             sampler_name = None,
             log_tokenization= False,
+            with_variations = None,
+            variation_amount = 0.0,
             **args,
     ):  # eat up additional cruft
         """
@@ -244,6 +246,8 @@ class T2I:
            ddim_eta // image randomness (eta=0.0 means the same seed always produces the same image)
            step_callback // a function or method that will be called each step
            image_callback // a function or method that will be called each time an image is generated
+           with_variations // a weighted list [(seed_1, weight_1), (seed_2, weight_2), ...] of variations which should be applied before doing any generation
+           variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image)
 
         To use the step callback, define a function that receives two arguments:
         - Image GPU data
@@ -262,7 +266,6 @@ class T2I:
         """
         # TODO: convert this into a getattr() loop
         steps = steps or self.steps
-        seed = seed or self.seed
         width = width or self.width
         height = height or self.height
         cfg_scale = cfg_scale or self.cfg_scale
@@ -270,6 +273,7 @@ class T2I:
         iterations = iterations or self.iterations
         strength = strength or self.strength
         self.log_tokenization = log_tokenization
+        with_variations = [] if with_variations is None else with_variations
 
         model = (
             self.load_model()
@@ -278,7 +282,20 @@ class T2I:
         assert (
             0.0 <= strength <= 1.0
         ), 'can only work with strength in [0.0, 1.0]'
+        assert (
+            0.0 <= variation_amount <= 1.0
+        ), '-v --variation_amount must be in [0.0, 1.0]'
 
+        if len(with_variations) > 0:
+            assert seed is not None,\
+                'seed must be specified when using with_variations'
+            if variation_amount == 0.0:
+                assert iterations == 1,\
+                    'when using --with_variations, multiple iterations are only possible when using --variation_amount'
+            assert all(0 <= weight <= 1 for _, weight in with_variations),\
+                f'variation weights must be in [0.0, 1.0]: got {[weight for _, weight in with_variations]}'
+
+        seed = seed or self.seed
         width, height, _ = self._resolution_check(width, height, log=True)
 
         # TODO: - Check if this is still necessary to run on M1 devices.
@@ -301,24 +318,25 @@ class T2I:
         try:
             if init_img:
                 assert os.path.exists(init_img), f'{init_img}: File not found'
-                images_iterator = self._img2img(
+                init_image = self._load_img(init_img, width, height, fit).to(self.device)
+                with scope(self.device.type):
+                    init_latent = self.model.get_first_stage_encoding(
+                        self.model.encode_first_stage(init_image)
+                    )  # move to latent space
+
+                make_image = self._img2img(
                     prompt,
-                    precision_scope=scope,
                     steps=steps,
                     cfg_scale=cfg_scale,
                     ddim_eta=ddim_eta,
                     skip_normalize=skip_normalize,
-                    init_img=init_img,
-                    width=width,
-                    height=height,
-                    fit=fit,
+                    init_latent=init_latent,
                     strength=strength,
                     callback=step_callback,
                 )
             else:
-                images_iterator = self._txt2img(
+                make_image = self._txt2img(
                     prompt,
-                    precision_scope=scope,
                     steps=steps,
                     cfg_scale=cfg_scale,
                     ddim_eta=ddim_eta,
@@ -328,11 +346,45 @@ class T2I:
                     callback=step_callback,
                 )
 
+            def get_noise():
+                if init_img:
+                    return torch.randn_like(init_latent, device=self.device)
+                else:
+                    return torch.randn([1,
+                                        self.latent_channels,
+                                        height // self.downsampling_factor,
+                                        width // self.downsampling_factor],
+                                       device=self.device)
+
+            initial_noise = None
+            if variation_amount > 0 or len(with_variations) > 0:
+                # use fixed initial noise plus random noise per iteration
+                seed_everything(seed)
+                initial_noise = get_noise()
+                for v_seed, v_weight in with_variations:
+                    seed = v_seed
+                    seed_everything(seed)
+                    next_noise = get_noise()
+                    initial_noise = self.slerp(v_weight, initial_noise, next_noise)
+                if variation_amount > 0:
+                    random.seed() # reset RNG to an actually random state, so we can get a random seed for variations
+                    seed = random.randrange(0,np.iinfo(np.uint32).max)
+
             device_type = choose_autocast_device(self.device)
             with scope(device_type), self.model.ema_scope():
                 for n in trange(iterations, desc='Generating'):
-                    seed_everything(seed)
-                    image = next(images_iterator)
+                    x_T = None
+                    if variation_amount > 0:
+                        seed_everything(seed)
+                        target_noise = get_noise()
+                        x_T = self.slerp(variation_amount, initial_noise, target_noise)
+                    elif initial_noise is not None:
+                        # i.e. we specified particular variations
+                        x_T = initial_noise
+                    else:
+                        seed_everything(seed)
+                        # make_image will do the equivalent of get_noise itself
+                    image = make_image(x_T)
                     results.append([image, seed])
                     if image_callback is not None:
                         image_callback(image, seed)
@@ -406,7 +458,6 @@ class T2I:
     def _txt2img(
         self,
         prompt,
-        precision_scope,
         steps,
         cfg_scale,
         ddim_eta,
@@ -416,12 +467,13 @@ class T2I:
         callback,
     ):
         """
-        An infinite iterator of images from the prompt.
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it
         """
 
         sampler = self.sampler
 
-        while True:
+        def make_image(x_T):
             uc, c = self._get_uc_and_c(prompt, skip_normalize)
             shape = [
                 self.latent_channels,
@@ -431,6 +483,7 @@ class T2I:
             samples, _ = sampler.sample(
                 batch_size=1,
                 S=steps,
+                x_T=x_T,
                 conditioning=c,
                 shape=shape,
                 verbose=False,
@@ -439,26 +492,24 @@ class T2I:
                 eta=ddim_eta,
                 img_callback=callback
             )
-            yield self._sample_to_image(samples)
+            return self._sample_to_image(samples)
+        return make_image
 
     @torch.no_grad()
     def _img2img(
         self,
         prompt,
-        precision_scope,
         steps,
         cfg_scale,
         ddim_eta,
         skip_normalize,
-        init_img,
-        width,
-        height,
-        fit,
+        init_latent,
         strength,
         callback,  # Currently not implemented for img2img
     ):
         """
-        An infinite iterator of images from the prompt and the initial image
+        Returns a function returning an image derived from the prompt and the initial image
+        Return value depends on the seed at the time you call it
         """
 
         # PLMS sampler not supported yet, so ignore previous sampler
@@ -470,24 +521,20 @@ class T2I:
         else:
             sampler = self.sampler
 
-        init_image = self._load_img(init_img, width, height,fit).to(self.device)
-        with precision_scope(self.device.type):
-            init_latent = self.model.get_first_stage_encoding(
-                self.model.encode_first_stage(init_image)
-            )  # move to latent space
-
         sampler.make_schedule(
             ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
         )
 
         t_enc = int(strength * steps)
 
-        while True:
+        def make_image(x_T):
             uc, c = self._get_uc_and_c(prompt, skip_normalize)
 
             # encode (scaled latent)
             z_enc = sampler.stochastic_encode(
-                init_latent, torch.tensor([t_enc]).to(self.device)
+                init_latent,
+                torch.tensor([t_enc]).to(self.device),
+                noise=x_T
             )
             # decode it
             samples = sampler.decode(
@@ -498,7 +545,8 @@ class T2I:
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
             )
-            yield self._sample_to_image(samples)
+            return self._sample_to_image(samples)
+        return make_image
 
     # TODO: does this actually need to run every loop? does anything in it vary by random seed?
     def _get_uc_and_c(self, prompt, skip_normalize):
@@ -513,8 +561,7 @@ class T2I:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(weighted_subprompts)):
-                subprompt, weight = weighted_subprompts[i]
+            for subprompt, weight in weighted_subprompts:
                 self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
@@ -619,7 +666,7 @@ class T2I:
         print(
             f'>> loaded input image of size {image.width}x{image.height} from {path}'
         )
-
+
         # The logic here is:
         # 1. If "fit" is true, then the image will be fit into the bounding box defined
         # by width and height. It will do this in a way that preserves the init image's
@@ -644,7 +691,7 @@ class T2I:
         if resize_needed:
             return InitImageResizer(image).resize(x,y)
         return image
-
+
 
     def _fit_image(self,image,max_dimensions):
         w,h = max_dimensions
@@ -677,10 +724,10 @@ class T2I:
                 (?:\\\:|[^:])+  # match one or more non ':' characters or escaped colons '\:'
                 )               # end 'prompt'
                 (?:             # non-capture group
-                :+              # match one or more ':' characters
+                :+              # match one or more ':' characters
                 (?P<weight>     # capture group for 'weight'
                 -?\d+(?:\.\d+)? # match positive or negative integer or decimal number
-                )?              # end weight capture group, make optional
+                )?              # end weight capture group, make optional
                 \s*             # strip spaces after weight
                 |               # OR
                 $               # else, if no ':' then match end of line
@@ -741,3 +788,41 @@ class T2I:
             print(">> This input is larger than your defaults. If you run out of memory, please use a smaller image.")
 
         return width, height, resize_needed
+
+
+    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
+        '''
+        Spherical linear interpolation
+        Args:
+            t (float/np.ndarray): Float value between 0.0 and 1.0
+            v0 (np.ndarray): Starting vector
+            v1 (np.ndarray): Final vector
+            DOT_THRESHOLD (float): Threshold for considering the two vectors as
+                                   collinear. Not recommended to alter this.
+        Returns:
+            v2 (np.ndarray): Interpolation vector between v0 and v1
+        '''
+        inputs_are_torch = False
+        if not isinstance(v0, np.ndarray):
+            inputs_are_torch = True
+            v0 = v0.detach().cpu().numpy()
+        if not isinstance(v1, np.ndarray):
+            inputs_are_torch = True
+            v1 = v1.detach().cpu().numpy()
+
+        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
+        if np.abs(dot) > DOT_THRESHOLD:
+            v2 = (1 - t) * v0 + t * v1
+        else:
+            theta_0 = np.arccos(dot)
+            sin_theta_0 = np.sin(theta_0)
+            theta_t = theta_0 * t
+            sin_theta_t = np.sin(theta_t)
+            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+            s1 = sin_theta_t / sin_theta_0
+            v2 = s0 * v0 + s1 * v1
+
+        if inputs_are_torch:
+            v2 = torch.from_numpy(v2).to(self.device)
+
+        return v2
diff --git a/scripts/dream.py b/scripts/dream.py
index 1535ac386c..54d6f86e77 100755
--- a/scripts/dream.py
+++ b/scripts/dream.py
@@ -181,9 +181,32 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                 print(f'No previous seed at position {opt.seed} found')
                 opt.seed = None
 
-        normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
         do_grid = opt.grid or t2i.grid
-        individual_images = not do_grid
+
+        if opt.with_variations is not None:
+            # shotgun parsing, woo
+            parts = []
+            broken = False # python doesn't have labeled loops...
+            for part in opt.with_variations.split(';'):
+                seed_and_weight = part.split(',')
+                if len(seed_and_weight) != 2:
+                    print(f'could not parse with_variation part "{part}"')
+                    broken = True
+                    break
+                try:
+                    seed = int(seed_and_weight[0])
+                    weight = float(seed_and_weight[1])
+                except ValueError:
+                    print(f'could not parse with_variation part "{part}"')
+                    broken = True
+                    break
+                parts.append([seed, weight])
+            if broken:
+                continue
+            if len(parts) > 0:
+                opt.with_variations = parts
+            else:
+                opt.with_variations = None
 
         if opt.outdir:
             if not os.path.exists(opt.outdir):
@@ -211,7 +234,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
             file_writer = PngWriter(current_outdir)
             prefix = file_writer.unique_prefix()
             seeds = set()
-            results = []
+            results = [] # list of filename, prompt pairs
             grid_images = dict() # seed -> Image, only used if `do_grid`
             def image_writer(image, seed, upscaled=False):
                 if do_grid:
@@ -221,10 +244,26 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                         filename = f'{prefix}.{seed}.postprocessed.png'
                     else:
                         filename = f'{prefix}.{seed}.png'
-                    path = file_writer.save_image_and_prompt_to_png(image, f'{normalized_prompt} -S{seed}', filename)
+                    if opt.variation_amount > 0:
+                        iter_opt = argparse.Namespace(**vars(opt)) # copy
+                        this_variation = [[seed, opt.variation_amount]]
+                        if opt.with_variations is None:
+                            iter_opt.with_variations = this_variation
+                        else:
+                            iter_opt.with_variations = opt.with_variations + this_variation
+                        iter_opt.variation_amount = 0
+                        normalized_prompt = PromptFormatter(t2i, iter_opt).normalize_prompt()
+                        metadata_prompt = f'{normalized_prompt} -S{iter_opt.seed}'
+                    elif opt.with_variations is not None:
+                        normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
+                        metadata_prompt = f'{normalized_prompt} -S{opt.seed}' # use the original seed - the per-iteration value is the last variation-seed
+                    else:
+                        normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
+                        metadata_prompt = f'{normalized_prompt} -S{seed}'
+                    path = file_writer.save_image_and_prompt_to_png(image, metadata_prompt, filename)
                     if (not upscaled) or opt.save_original:
                         # only append to results if we didn't overwrite an earlier output
-                        results.append([path, seed])
+                        results.append([path, metadata_prompt])
 
                 seeds.add(seed)
 
@@ -235,11 +274,12 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
                 first_seed = next(iter(seeds))
                 filename = f'{prefix}.{first_seed}.png'
                 # TODO better metadata for grid images
-                metadata_prompt = f'{normalized_prompt} -S{first_seed}'
+                normalized_prompt = PromptFormatter(t2i, opt).normalize_prompt()
+                metadata_prompt = f'{normalized_prompt} -S{first_seed} --grid -N{len(grid_images)}'
                 path = file_writer.save_image_and_prompt_to_png(
                     grid_img, metadata_prompt, filename
                 )
-                results = [[path, seeds]]
+                results = [[path, metadata_prompt]]
 
             last_seeds = list(seeds)
 
@@ -253,7 +293,7 @@ def main_loop(t2i, outdir, prompt_as_dir, parser, infile):
 
         print('Outputs:')
         log_path = os.path.join(current_outdir, 'dream_log.txt')
-        write_log_message(normalized_prompt, results, log_path)
+        write_log_message(results, log_path)
 
     print('goodbye!')
 
@@ -291,9 +331,9 @@ def dream_server_loop(t2i):
     dream_server.server_close()
 
 
-def write_log_message(prompt, results, log_path):
+def write_log_message(results, log_path):
     """logs the name of the output image, prompt, and prompt args to the terminal and log file"""
-    log_lines = [f'{r[0]}: {prompt} -S{r[1]}\n' for r in results]
+    log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
     print(*log_lines, sep='')
 
     with open(log_path, 'a', encoding='utf-8') as file:
@@ -546,6 +586,20 @@ def create_cmd_parser():
         action='store_true',
         help='shows how the prompt is split into tokens'
     )
+    parser.add_argument(
+        '-v',
+        '--variation_amount',
+        default=0.0,
+        type=float,
+        help='If > 0, generates variations on the initial seed instead of random seeds per iteration. Must be between 0 and 1. Higher values will be more different.'
+    )
+    parser.add_argument(
+        '-V',
+        '--with_variations',
+        default=None,
+        type=str,
+        help='list of variations to apply, in the format `seed,weight;seed,weight;...`'
+    )
     return parser
 
 
diff --git a/static/variation_walkthru/000001.3357757885.png b/static/variation_walkthru/000001.3357757885.png
new file mode 100644
index 0000000000..b9aa4a78ed
Binary files /dev/null and b/static/variation_walkthru/000001.3357757885.png differ
diff --git a/static/variation_walkthru/000002.1614299449.png b/static/variation_walkthru/000002.1614299449.png
new file mode 100644
index 0000000000..0db167ae6c
Binary files /dev/null and b/static/variation_walkthru/000002.1614299449.png differ
diff --git a/static/variation_walkthru/000002.3647897225.png b/static/variation_walkthru/000002.3647897225.png
new file mode 100644
index 0000000000..7fe1f29227
Binary files /dev/null and b/static/variation_walkthru/000002.3647897225.png differ
diff --git a/static/variation_walkthru/000003.1614299449.png b/static/variation_walkthru/000003.1614299449.png
new file mode 100644
index 0000000000..b7f6ae7613
Binary files /dev/null and b/static/variation_walkthru/000003.1614299449.png differ
diff --git a/static/variation_walkthru/000004.3747154981.png b/static/variation_walkthru/000004.3747154981.png
new file mode 100644
index 0000000000..e6ac5f3bc9
Binary files /dev/null and b/static/variation_walkthru/000004.3747154981.png differ
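
Reviewer note: the patch combines three pieces (parsing the `-V seed,weight;...` string, spherical interpolation, and folding each variation into the base noise), and a small standalone sketch may make the flow easier to follow. This is an illustration under assumptions, not the code in the patch: `parse_variations` and `combine_variations` are hypothetical helper names, the latent shape `(1, 4, 64, 64)` is assumed for a 512x512 image, and NumPy's seeded generator stands in for the `seed_everything` plus `torch.randn` calls used in `ldm/simplet2i.py`, so the noise values will not match what the real tool produces.

```python
import numpy as np


def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherical linear interpolation between two same-shaped arrays."""
    v0 = np.asarray(v0, dtype=np.float64)
    v1 = np.asarray(v1, dtype=np.float64)
    # Cosine of the angle between the two arrays, treated as flat vectors.
    dot = np.sum(v0 * v1) / (np.linalg.norm(v0) * np.linalg.norm(v1))
    if np.abs(dot) > dot_threshold:
        # Nearly collinear inputs: plain linear interpolation is stable here.
        return (1 - t) * v0 + t * v1
    theta_0 = np.arccos(dot)        # angle between v0 and v1
    theta_t = theta_0 * t           # angle between v0 and the result
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1


def parse_variations(text):
    """Parse 'seed,weight;seed,weight' into a list of (int, float) pairs.

    Hypothetical helper; raises ValueError on a malformed part, whereas the
    patch prints a message and skips the command instead.
    """
    pairs = []
    for part in text.split(";"):
        seed_text, weight_text = part.split(",")
        pairs.append((int(seed_text), float(weight_text)))
    return pairs


def combine_variations(base_seed, variations, shape=(1, 4, 64, 64)):
    """Fold each (seed, weight) pair into the base noise, in order."""
    # Assumption: a NumPy generator replaces seed_everything + torch.randn.
    noise = np.random.default_rng(base_seed).standard_normal(shape)
    for v_seed, v_weight in variations:
        next_noise = np.random.default_rng(v_seed).standard_normal(shape)
        noise = slerp(v_weight, noise, next_noise)
    return noise


if __name__ == "__main__":
    pairs = parse_variations("3647897225,0.1;1614299449,0.1")
    x_T = combine_variations(3357757885, pairs)
    print(pairs, x_T.shape)
```

Because every step is seeded, the same `-S` and `-V` values always rebuild the same starting noise, which is why the walkthrough insists on passing the original seed alongside the subseeds. Note that the variations are applied sequentially, so the order of the `seed,weight` pairs matters.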