From ca8d9fb885b2211815f21027671defb19902e52a Mon Sep 17 00:00:00 2001
From: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Date: Mon, 20 Feb 2023 06:33:19 -0600
Subject: [PATCH] Add symmetry to generation (#2675)

Added symmetry to Invoke based on discussions with @damian0815. This can
currently only be activated via the CLI with the `--h_symmetry_time_pct`
and `--v_symmetry_time_pct` options. Those take values from 0.0-1.0,
exclusive, indicating the percentage through generation at which symmetry
is applied as a one-time operation. To have symmetry in either axis
applied after the first step, use a very low value like 0.001.
---
 docs/features/CLI.md                          |  2 +
 docs/features/PROMPTS.md                      |  2 +-
 ldm/generate.py                               |  6 ++
 ldm/invoke/args.py                            | 19 +++++-
 ldm/invoke/generator/base.py                  |  3 +
 ldm/invoke/generator/img2img.py               | 13 ++--
 ldm/invoke/generator/txt2img.py               | 15 +++--
 ldm/invoke/generator/txt2img2img.py           | 37 ++++++++---
 ldm/invoke/readline.py                        |  6 +-
 .../diffusion/shared_invokeai_diffusion.py    | 62 ++++++++++++++++++-
 10 files changed, 142 insertions(+), 23 deletions(-)

diff --git a/docs/features/CLI.md b/docs/features/CLI.md
index 6379f7758c..d346b31000 100644
--- a/docs/features/CLI.md
+++ b/docs/features/CLI.md
@@ -214,6 +214,8 @@ Here are the invoke> command that apply to txt2img:
 | `--variation ` | `-v` | `0.0` | Add a bit of noise (0.0=none, 1.0=high) to the image in order to generate a series of variations. Usually used in combination with `-S` and `-n` to generate a series a riffs on a starting image. See [Variations](./VARIATIONS.md). |
 | `--with_variations ` | | `None` | Combine two or more variations. See [Variations](./VARIATIONS.md) for now to use this. |
 | `--save_intermediates ` | | `None` | Save the image from every nth step into an "intermediates" folder inside the output directory |
+| `--h_symmetry_time_pct <float>` | | `None` | Create symmetry along the X axis at the desired percent complete of the generation process. (Must be between 0.0 and 1.0; set to a very small number like 0.0001 for just after the first step of generation.) |
+| `--v_symmetry_time_pct <float>` | | `None` | Create symmetry along the Y axis at the desired percent complete of the generation process. (Must be between 0.0 and 1.0; set to a very small number like 0.0001 for just after the first step of generation.) |

!!! 
note diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index 5413cc5e55..85919a5b29 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -40,7 +40,7 @@ for adj in adjectives: print(f'a {adj} day -A{samp} -C{cg}') ``` -It's output looks like this (abbreviated): +Its output looks like this (abbreviated): ```bash a sunny day -Aklms -C7.5 diff --git a/ldm/generate.py b/ldm/generate.py index a861c1a27c..d7ea87a5fd 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -320,6 +320,8 @@ class Generate: variation_amount=0.0, threshold=0.0, perlin=0.0, + h_symmetry_time_pct = None, + v_symmetry_time_pct = None, karras_max=None, outdir=None, # these are specific to img2img and inpaint @@ -390,6 +392,8 @@ class Generate: variation_amount // optional 0-1 value to slerp from -S noise to random noise (allows variations on an image) threshold // optional value >=0 to add thresholding to latent values for k-diffusion samplers (0 disables) perlin // optional 0-1 value to add a percentage of perlin noise to the initial noise + h_symmetry_time_pct // optional 0-1 value that indicates the time at which horizontal symmetry is applied + v_symmetry_time_pct // optional 0-1 value that indicates the time at which vertical symmetry is applied embiggen // scale factor relative to the size of the --init_img (-I), followed by ESRGAN upscaling strength (0-1.0), followed by minimum amount of overlap between tiles as a decimal ratio (0 - 1.0) or number of pixels embiggen_tiles // list of tiles by number in order to process and replace onto the image e.g. `0 2 4` embiggen_strength // strength for embiggen. 0.0 preserves image exactly, 1.0 replaces it completely @@ -561,6 +565,8 @@ class Generate: strength=strength, threshold=threshold, perlin=perlin, + h_symmetry_time_pct=h_symmetry_time_pct, + v_symmetry_time_pct=v_symmetry_time_pct, embiggen=embiggen, embiggen_tiles=embiggen_tiles, embiggen_strength=embiggen_strength, diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 1bd1aa46ab..1ddba832c0 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -272,6 +272,10 @@ class Args(object): switches.append('--seamless') if a['hires_fix']: switches.append('--hires_fix') + if a['h_symmetry_time_pct']: + switches.append(f'--h_symmetry_time_pct {a["h_symmetry_time_pct"]}') + if a['v_symmetry_time_pct']: + switches.append(f'--v_symmetry_time_pct {a["v_symmetry_time_pct"]}') # img2img generations have parameters relevant only to them and have special handling if a['init_img'] and len(a['init_img'])>0: @@ -845,6 +849,18 @@ class Args(object): type=float, help='Perlin noise scale (0.0 - 1.0) - add perlin noise to the initialization instead of the usual gaussian noise.', ) + render_group.add_argument( + '--h_symmetry_time_pct', + default=None, + type=float, + help='Horizontal symmetry point (0.0 - 1.0) - apply horizontal symmetry at this point in image generation.', + ) + render_group.add_argument( + '--v_symmetry_time_pct', + default=None, + type=float, + help='Vertical symmetry point (0.0 - 1.0) - apply vertical symmetry at this point in image generation.', + ) render_group.add_argument( '--fnformat', default='{prefix}.{seed}.png', @@ -1151,7 +1167,8 @@ def metadata_dumps(opt, # remove any image keys not mentioned in RFC #266 rfc266_img_fields = ['type','postprocessing','sampler','prompt','seed','variations','steps', 'cfg_scale','threshold','perlin','step_number','width','height','extra','strength','seamless' - 'init_img','init_mask','facetool','facetool_strength','upscale'] + 
'init_img','init_mask','facetool','facetool_strength','upscale','h_symmetry_time_pct', + 'v_symmetry_time_pct'] rfc_dict ={} for item in image_dict.items(): diff --git a/ldm/invoke/generator/base.py b/ldm/invoke/generator/base.py index 0738140171..d89fb48aff 100644 --- a/ldm/invoke/generator/base.py +++ b/ldm/invoke/generator/base.py @@ -64,6 +64,7 @@ class Generator: def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None, image_callback=None, step_callback=None, threshold=0.0, perlin=0.0, + h_symmetry_time_pct=None, v_symmetry_time_pct=None, safety_checker:dict=None, free_gpu_mem: bool=False, **kwargs): @@ -81,6 +82,8 @@ class Generator: step_callback = step_callback, threshold = threshold, perlin = perlin, + h_symmetry_time_pct = h_symmetry_time_pct, + v_symmetry_time_pct = v_symmetry_time_pct, attention_maps_callback = attention_maps_callback, **kwargs ) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 0b762f7c98..67a588234b 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -16,8 +16,8 @@ class Img2Img(Generator): self.init_latent = None # by get_noise() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,init_image,strength,step_callback=None,threshold=0.0,perlin=0.0, - attention_maps_callback=None, + conditioning,init_image,strength,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0, + h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image @@ -33,8 +33,13 @@ class Img2Img(Generator): conditioning_data = ( ConditioningData( uc, c, cfg_scale, extra_conditioning_info, - postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None) - .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) + postprocessing_settings=PostprocessingSettings( + threshold=threshold, + warmup=warmup, + h_symmetry_time_pct=h_symmetry_time_pct, + v_symmetry_time_pct=v_symmetry_time_pct + ) + ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) def make_image(x_T): diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 76da3e4904..9903de1309 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -15,8 +15,8 @@ class Txt2Img(Generator): @torch.no_grad() def get_make_image(self,prompt,sampler,steps,cfg_scale,ddim_eta, - conditioning,width,height,step_callback=None,threshold=0.0,perlin=0.0, - attention_maps_callback=None, + conditioning,width,height,step_callback=None,threshold=0.0,warmup=0.2,perlin=0.0, + h_symmetry_time_pct=None,v_symmetry_time_pct=None,attention_maps_callback=None, **kwargs): """ Returns a function returning an image derived from the prompt and the initial image @@ -33,8 +33,13 @@ class Txt2Img(Generator): conditioning_data = ( ConditioningData( uc, c, cfg_scale, extra_conditioning_info, - postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None) - .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) + postprocessing_settings=PostprocessingSettings( + threshold=threshold, + warmup=warmup, + h_symmetry_time_pct=h_symmetry_time_pct, + v_symmetry_time_pct=v_symmetry_time_pct + ) + ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) def make_image(x_T) -> PIL.Image.Image: pipeline_output = pipeline.image_from_embeddings( @@ -44,8 +49,10 @@ class 
Txt2Img(Generator):
                 conditioning_data=conditioning_data,
                 callback=step_callback,
             )
+
             if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
                 attention_maps_callback(pipeline_output.attention_map_saver)
+
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

         return make_image

diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index ff5d3a4d26..a39dfccc3a 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -21,12 +21,14 @@ class Txt2Img2Img(Generator):
     def get_make_image(self, prompt:str, sampler, steps:int, cfg_scale:float, ddim_eta,
                        conditioning, width:int, height:int, strength:float,
-                       step_callback:Optional[Callable]=None, threshold=0.0, **kwargs):
+                       step_callback:Optional[Callable]=None, threshold=0.0, warmup=0.2, perlin=0.0,
+                       h_symmetry_time_pct=None, v_symmetry_time_pct=None, attention_maps_callback=None, **kwargs):
         """
         Returns a function returning an image derived from the prompt and the initial image
         Return value depends on the seed at the time you call it
         kwargs are 'width' and 'height'
         """
+        self.perlin = perlin

         # noinspection PyTypeChecker
         pipeline: StableDiffusionGeneratorPipeline = self.model
@@ -36,8 +38,13 @@ class Txt2Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
-            .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
+                postprocessing_settings=PostprocessingSettings(
+                    threshold=threshold,
+                    warmup=warmup,
+                    h_symmetry_time_pct=h_symmetry_time_pct,
+                    v_symmetry_time_pct=v_symmetry_time_pct
+                )
+            ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T):

@@ -69,19 +76,28 @@ class Txt2Img2Img(Generator):
             if clear_cuda_cache is not None:
                 clear_cuda_cache()

-            second_pass_noise = self.get_noise_like(resized_latents)
+            second_pass_noise = self.get_noise_like(resized_latents, override_perlin=True)
+
+            # Clear symmetry for the second pass
+            from dataclasses import replace
+            new_postprocessing_settings = replace(conditioning_data.postprocessing_settings, h_symmetry_time_pct=None)
+            new_postprocessing_settings = replace(new_postprocessing_settings, v_symmetry_time_pct=None)
+            new_conditioning_data = replace(conditioning_data, postprocessing_settings=new_postprocessing_settings)

             verbosity = get_verbosity()
             set_verbosity_error()
             pipeline_output = pipeline.img2img_from_latents_and_embeddings(
                 resized_latents,
                 num_inference_steps=steps,
-                conditioning_data=conditioning_data,
+                conditioning_data=new_conditioning_data,
                 strength=strength,
                 noise=second_pass_noise,
                 callback=step_callback)
             set_verbosity(verbosity)

+            if pipeline_output.attention_map_saver is not None and attention_maps_callback is not None:
+                attention_maps_callback(pipeline_output.attention_map_saver)
+
             return pipeline.numpy_to_pil(pipeline_output.images)[0]

@@ -95,13 +111,13 @@ class Txt2Img2Img(Generator):

         return make_image

-    def get_noise_like(self, like: torch.Tensor):
+    def get_noise_like(self, like: torch.Tensor, override_perlin: bool=False):
         device = like.device
         if device.type == 'mps':
             x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
             x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
-        if self.perlin > 0.0:
+        if self.perlin > 0.0 and not override_perlin:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
         return x
@@ -139,6 +155,9 @@ class Txt2Img2Img(Generator):
         shape = (1, channels,
                  scaled_height // self.downsampling_factor, scaled_width // self.downsampling_factor)
         if self.use_mps_noise or device.type == 'mps':
-            return torch.randn(shape, dtype=self.torch_dtype(), device='cpu').to(device)
+            tensor = torch.empty(size=shape, device='cpu')
+            tensor = self.get_noise_like(like=tensor).to(device)
         else:
-            return torch.randn(shape, dtype=self.torch_dtype(), device=device)
+            tensor = torch.empty(size=shape, device=device)
+            tensor = self.get_noise_like(like=tensor)
+        return tensor

diff --git a/ldm/invoke/readline.py b/ldm/invoke/readline.py
index 1e9b31ea8d..542bdeeaed 100644
--- a/ldm/invoke/readline.py
+++ b/ldm/invoke/readline.py
@@ -58,6 +58,8 @@ COMMANDS = (
     '--inpaint_replace','-r',
     '--png_compression','-z',
     '--text_mask','-tm',
+    '--h_symmetry_time_pct',
+    '--v_symmetry_time_pct',
     '!fix','!fetch','!replay','!history','!search','!clear',
     '!models','!switch','!import_model','!optimize_model','!convert_model','!edit_model','!del_model',
     '!mask','!triggers',
@@ -138,7 +140,7 @@ class Completer(object):
         elif re.match('^'+'|'.join(MODEL_COMMANDS),buffer):
             self.matches= self._model_completions(text, state)

-        # looking for a ckpt model 
+        # looking for a ckpt model
         elif re.match('^'+'|'.join(CKPT_MODEL_COMMANDS),buffer):
             self.matches= self._model_completions(text, state, ckpt_only=True)

@@ -255,7 +257,7 @@ class Completer(object):
         update our list of models
         '''
         self.models = models
-        
+
     def _seed_completions(self, text, state):
         m = re.search('(-S\s?|--seed[=\s]?)(\d*)',text)
         if m:

diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index ca3e608fc0..66e0b94655 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -18,6 +18,8 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 class PostprocessingSettings:
     threshold: float
     warmup: float
+    h_symmetry_time_pct: Optional[float]
+    v_symmetry_time_pct: Optional[float]


 class InvokeAIDiffuserComponent:
@@ -30,7 +32,7 @@ class InvokeAIDiffuserComponent:
     * Hybrid conditioning (used for inpainting)
     '''
     debug_thresholding = False
-    
+    last_percent_through = 0.0

     @dataclass
     class ExtraConditioningInfo:
@@ -56,6 +58,7 @@ class InvokeAIDiffuserComponent:
         self.is_running_diffusers = is_running_diffusers
         self.model_forward_callback = model_forward_callback
         self.cross_attention_control_context = None
+        self.last_percent_through = 0.0

     @contextmanager
     def custom_attention_context(self,
@@ -164,6 +167,7 @@ class InvokeAIDiffuserComponent:
         if postprocessing_settings is not None:
             percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
             latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
+            latents = self.apply_symmetry(postprocessing_settings, latents, percent_through)
         return latents

     def calculate_percent_through(self, sigma, step_index, total_step_count):
@@ -292,8 +296,12 @@ class InvokeAIDiffuserComponent:
         self,
         postprocessing_settings: PostprocessingSettings,
         latents: torch.Tensor,
-        percent_through
+        percent_through: float
     ) -> torch.Tensor:
+
+        if postprocessing_settings.threshold is None or postprocessing_settings.threshold == 0.0:
+            return latents
+
         threshold = postprocessing_settings.threshold
         warmup = postprocessing_settings.warmup

@@ -342,6 +350,56 @@ class InvokeAIDiffuserComponent:

         return latents

+    def apply_symmetry(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        percent_through: float
+    ) -> torch.Tensor:
+
+        # Reset our last percent through if this is our first step.
+        if percent_through == 0.0:
+            self.last_percent_through = 0.0
+
+        if postprocessing_settings is None:
+            return latents
+
+        # Check for out of bounds
+        h_symmetry_time_pct = postprocessing_settings.h_symmetry_time_pct
+        if (h_symmetry_time_pct is not None and (h_symmetry_time_pct <= 0.0 or h_symmetry_time_pct > 1.0)):
+            h_symmetry_time_pct = None
+
+        v_symmetry_time_pct = postprocessing_settings.v_symmetry_time_pct
+        if (v_symmetry_time_pct is not None and (v_symmetry_time_pct <= 0.0 or v_symmetry_time_pct > 1.0)):
+            v_symmetry_time_pct = None
+
+        dev = latents.device
+
+        latents = latents.to(device='cpu')
+
+        if (
+            h_symmetry_time_pct is not None and
+            self.last_percent_through < h_symmetry_time_pct and
+            percent_through >= h_symmetry_time_pct
+        ):
+            # Horizontal symmetry: mirror the left half onto the right along dim 3 (width) of the [B, C, H, W] latent
+            width = latents.shape[3]
+            x_flipped = torch.flip(latents, dims=[3])
+            latents = torch.cat([latents[:, :, :, 0:width // 2], x_flipped[:, :, :, width // 2:width]], dim=3)
+
+        if (
+            v_symmetry_time_pct is not None and
+            self.last_percent_through < v_symmetry_time_pct and
+            percent_through >= v_symmetry_time_pct
+        ):
+            # Vertical symmetry: mirror the top half onto the bottom along dim 2 (height) of the [B, C, H, W] latent
+            height = latents.shape[2]
+            y_flipped = torch.flip(latents, dims=[2])
+            latents = torch.cat([latents[:, :, 0:height // 2], y_flipped[:, :, height // 2:height]], dim=2)
+
+        self.last_percent_through = percent_through
+        return latents.to(device=dev)
+
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
             # percent_through will never reach 1.0 (but this is intended)
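
For reference, the half-latent mirroring that `apply_symmetry` performs can be sketched in isolation. This is a minimal illustration, not part of the patch: the `latents` tensor below is a hypothetical stand-in for the denoiser's working latent, assumed to be shaped [batch, channels, height, width] as in the code above. The horizontal case is shown; the vertical case is identical with `dims=[2]` and the height axis.

```python
import torch

# Hypothetical stand-in for the denoiser's working latent:
# a 4-D tensor shaped [batch, channels, height, width].
latents = torch.randn(1, 4, 64, 64)

# Horizontal symmetry: keep the left half and overwrite the right half
# with a mirror image of the left (a flip along the width axis, dim 3).
width = latents.shape[3]
x_flipped = torch.flip(latents, dims=[3])
latents = torch.cat(
    [latents[:, :, :, :width // 2], x_flipped[:, :, :, width // 2:]],
    dim=3,
)

# The result is mirror-symmetric about the vertical center line.
left = latents[:, :, :, :width // 2]
right = latents[:, :, :, width // 2:]
assert torch.equal(left, torch.flip(right, dims=[3]))
```

Because this is a one-time replacement of half the latent, triggering it just after the first step (the commit message suggests a very low value such as `0.001`) leaves the remaining denoising steps to blend the mirrored halves, while a late trigger leaves fewer steps to do so.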