From 093174942b127a3ffa99f27255989c47dd15abbe Mon Sep 17 00:00:00 2001
From: Jonathan <34005131+JPPhoto@users.noreply.github.com>
Date: Tue, 14 Feb 2023 18:00:34 -0600
Subject: [PATCH] Add thresholding for all diffusers types (#2479)

`generator` now asks `InvokeAIDiffuserComponent` to postprocess the latents after every step. Thresholding, now implemented as replacing latents that fall outside the threshold with random noise instead of clamping them, is applied at this point (sketched below). This postprocessing hook is also where symmetry and other latent manipulations can be attached in the future.
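
For reviewers, a minimal standalone sketch of the per-step behaviour (illustrative only: the helper name is invented here, and the warmup and rescaling details are simplified relative to the real `apply_threshold` in the diff below):

```python
# Illustrative sketch only; the real implementation is
# InvokeAIDiffuserComponent.apply_threshold in the diff below.
import torch

def sketch_apply_threshold(latents: torch.Tensor, threshold: float,
                           warmup: float, percent_through: float) -> torch.Tensor:
    # Relax the threshold during warmup, tightening to `threshold` once
    # percent_through passes the warmup fraction.
    if percent_through < warmup:
        current_threshold = threshold + threshold * 5 * (1 - percent_through / warmup)
    else:
        current_threshold = threshold
    if current_threshold <= 0:
        return latents

    latents = latents.clone()
    over = latents > current_threshold
    under = latents < -current_threshold
    # Values outside the threshold are replaced with random noise drawn
    # between zero and the bound, rather than being clamped to it.
    latents[over] = torch.rand_like(latents[over]) * current_threshold
    latents[under] = torch.rand_like(latents[under]) * -current_threshold
    return latents
```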

Note: code at this layer does not need to worry about MPS, since the relevant torch functions are wrapped and made MPS-safe in `generate.py`.
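
The kind of wrapping that note refers to looks roughly like the following (an assumed, illustrative sketch, not copied from `generate.py`): RNG calls are executed on the CPU and the result is moved back to the MPS device.

```python
import torch

# Assumed, illustrative sketch of an MPS-safe wrapper around torch.rand_like;
# the real wrapping lives in generate.py and may differ in detail.
_orig_rand_like = torch.rand_like

def _mps_safe_rand_like(tensor: torch.Tensor, **kwargs) -> torch.Tensor:
    if tensor.device.type == "mps":
        # Draw the random values on the CPU, then move them back to MPS.
        return _orig_rand_like(tensor.to("cpu"), **kwargs).to(tensor.device)
    return _orig_rand_like(tensor, **kwargs)

torch.rand_like = _mps_safe_rand_like
```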
---
 ldm/invoke/generator/diffusers_pipeline.py    | 18 ++++-
 ldm/invoke/generator/img2img.py               |  4 +-
 ldm/invoke/generator/txt2img.py               |  4 +-
 ldm/invoke/generator/txt2img2img.py           |  5 +-
 .../diffusion/shared_invokeai_diffusion.py    | 79 +++++++++++++------
 5 files changed, 73 insertions(+), 37 deletions(-)

diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 24626247cf..686fb40d3a 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ldm.invoke.globals import Globals
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
 
 
@@ -199,8 +199,10 @@ class ConditioningData:
     """
     extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """Additional arguments to pass to scheduler.step."""
-    threshold: Optional[ThresholdSettings] = None
+    """
+    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
+    """
+    postprocessing_settings: Optional[PostprocessingSettings] = None
 
     @property
     def dtype(self):
@@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                         total_step_count=len(timesteps),
                                         additional_guidance=additional_guidance)
                 latents = step_output.prev_sample
+
+                latents = self.invokeai_diffuser.do_latent_postprocessing(
+                    postprocessing_settings=conditioning_data.postprocessing_settings,
+                    latents=latents,
+                    sigma=batched_t,
+                    step_index=i,
+                    total_step_count=len(timesteps)
+                )
+
                 predicted_original = getattr(step_output, 'pred_original_sample', None)
 
                 # TODO resuscitate attention map saving
@@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.guidance_scale,
             step_index=step_index,
             total_step_count=total_step_count,
-            threshold=conditioning_data.threshold
         )
 
         # compute the previous noisy sample x_t -> x_t-1
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index bfa50617ef..0b762f7c98 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -7,7 +7,7 @@ from diffusers import logging
 
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Img2Img(Generator):
@@ -33,7 +33,7 @@ class Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
 
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 6578794fa7..76da3e4904 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -6,7 +6,7 @@ import torch
 
 from .base import Generator
 from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Txt2Img(Generator):
@@ -33,7 +33,7 @@ class Txt2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T) -> PIL.Image.Image:
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 9632a8d4b0..ff5d3a4d26 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
     ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Txt2Img2Img(Generator):
@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T):
@@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
                 conditioning_data=conditioning_data,
                 noise=x_T,
                 callback=step_callback,
-                # TODO: threshold = threshold,
             )
 
             # Get our initial generation width and height directly from the latent output so
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index f37bec789e..ca3e608fc0 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -15,7 +15,7 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
 
 @dataclass(frozen=True)
-class ThresholdSettings:
+class PostprocessingSettings:
     threshold: float
     warmup: float
 
@@ -121,7 +121,6 @@ class InvokeAIDiffuserComponent:
                                 unconditional_guidance_scale: float,
                                 step_index: Optional[int]=None,
                                 total_step_count: Optional[int]=None,
-                                threshold: Optional[ThresholdSettings]=None,
                           ):
         """
         :param x: current latents
@@ -130,7 +129,6 @@ class InvokeAIDiffuserComponent:
         :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
         :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
         :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
-        :param threshold: threshold to apply after each step
         :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """
 
@@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent:
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            if step_index is not None and total_step_count is not None:
-                # 🧨diffusers codepath
-                percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
-            else:
-                # legacy compvis codepath
-                # TODO remove when compvis codepath support is dropped
-                if step_index is None and sigma is None:
-                    raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
-                percent_through = self.estimate_percent_through(step_index, sigma)
+            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
             cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
 
         wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
@@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent:
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
 
-        if threshold:
-            combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
-
         return combined_next_x
 
+    def do_latent_postprocessing(
+        self,
+        postprocessing_settings: Optional[PostprocessingSettings],
+        latents: torch.Tensor,
+        sigma,
+        step_index,
+        total_step_count
+    ) -> torch.Tensor:
+        if postprocessing_settings is not None:
+            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
+            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
+        return latents
+
+    def calculate_percent_through(self, sigma, step_index, total_step_count):
+        if step_index is not None and total_step_count is not None:
+            # 🧨diffusers codepath
+            percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
+        else:
+            # legacy compvis codepath
+            # TODO remove when compvis codepath support is dropped
+            if step_index is None and sigma is None:
+                raise ValueError(
+                    "Either step_index or sigma is required when doing cross attention control, but both are None.")
+            percent_through = self.estimate_percent_through(step_index, sigma)
+        return percent_through
+
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
     def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
@@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
         combined_next_x = unconditioned_next_x + scaled_delta
         return combined_next_x
 
-    def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
-        warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
-        if warmup_scale < 1:
-            # This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
-            warming_threshold = 1 + (threshold - 1) * warmup_scale
-            current_threshold = np.clip(warming_threshold, 1, threshold)
+    def apply_threshold(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        percent_through
+    ) -> torch.Tensor:
+        threshold = postprocessing_settings.threshold
+        warmup = postprocessing_settings.warmup
+
+        if percent_through < warmup:
+            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
         else:
             current_threshold = threshold
 
         if current_threshold <= 0:
             return latents
+
         maxval = latents.max().item()
         minval = latents.min().item()
 
@@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
             outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
                   f"  | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
                   f"  | {outside / latents.numel() * 100:.2f}% values outside threshold")
 
         if maxval < current_threshold and minval > -current_threshold:
             return latents
 
+        num_altered = 0
+
+        # torch.rand_like is safe to use on MPS here because generate.py wraps it with an MPS-safe version.
+
         if maxval > current_threshold:
+            latents = torch.clone(latents)
             maxval = np.clip(maxval * scale, 1, current_threshold)
+            num_altered += torch.count_nonzero(latents > maxval)
+            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
 
         if minval < -current_threshold:
+            latents = torch.clone(latents)
             minval = np.clip(minval * scale, -current_threshold, -1)
+            num_altered += torch.count_nonzero(latents < minval)
+            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
 
         if self.debug_thresholding:
-            outside = torch.count_nonzero((latents < minval) | (latents > maxval))
             print(f"  | min,     , max = {minval:.3f},        , {maxval:.3f}\t(scaled by {scale})\n"
-                  f"  | {outside / latents.numel() * 100:.2f}% values will be clamped")
+                  f"  | {num_altered / latents.numel() * 100:.2f}% values altered")
 
-        return latents.clamp(minval, maxval)
+        return latents
 
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
@@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
         # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
 
         return uncond_latents + deltas_merged * global_guidance_scale
-