Add thresholding for all diffusers types (#2479)

`generator` now asks `InvokeAIDiffuserComponent` to do postprocessing work on the latents after every step. Thresholding, now implemented as replacing latent values that fall outside the threshold with random noise, is applied at this point. This postprocessing step is also where symmetry and other image latent manipulations can be hooked in in the future.

Note: code at this layer doesn't need to worry about MPS, because the relevant torch functions are wrapped and made MPS-safe by `generator.py`.
Jonathan 2023-02-14 18:00:34 -06:00 committed by GitHub
parent 0bc55a0d55
commit 093174942b
5 changed files with 73 additions and 37 deletions
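
As a rough, self-contained illustration of the behaviour described in the commit message, the sketch below reimplements just the noise-replacement step on a toy tensor. The helper name `toy_threshold` and the tensor shape are invented for this example; the real `apply_threshold` in `shared_invokeai_diffusion.py` (the last file below) additionally applies a warmup schedule, rescales the clamp bounds, and prints debug output.

```python
import torch

def toy_threshold(latents: torch.Tensor, current_threshold: float) -> torch.Tensor:
    """Replace latent values outside +/- current_threshold with random noise, then clamp."""
    if latents.max().item() < current_threshold and latents.min().item() > -current_threshold:
        return latents  # nothing out of range, leave untouched
    latents = latents.clone()
    high = latents > current_threshold
    low = latents < -current_threshold
    latents[high] = torch.rand_like(latents[high]) * current_threshold
    latents[low] = torch.rand_like(latents[low]) * -current_threshold
    return latents.clamp(-current_threshold, current_threshold)

x = torch.randn(1, 4, 8, 8) * 4          # toy "latents" with some large outliers
y = toy_threshold(x, current_threshold=2.0)
print(x.abs().max().item(), y.abs().max().item())  # the second value is bounded by 2.0
```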

File: ldm/invoke/generator/diffusers_pipeline.py

@@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager
@@ -199,8 +199,10 @@ class ConditioningData:
"""
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""Additional arguments to pass to scheduler.step."""
threshold: Optional[ThresholdSettings] = None
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
@property
def dtype(self):
@@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps),
additional_guidance=additional_guidance)
latents = step_output.prev_sample
latents = self.invokeai_diffuser.do_latent_postprocessing(
postprocessing_settings=conditioning_data.postprocessing_settings,
latents=latents,
sigma=batched_t,
step_index=i,
total_step_count=len(timesteps)
)
predicted_original = getattr(step_output, 'pred_original_sample', None)
# TODO resuscitate attention map saving
@@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale,
step_index=step_index,
total_step_count=total_step_count,
threshold=conditioning_data.threshold
)
# compute the previous noisy sample x_t -> x_t-1

File: ldm/invoke/generator/img2img.py

@@ -7,7 +7,7 @@ from diffusers import logging
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Img2Img(Generator):
@@ -33,7 +33,7 @@ class Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

File: ldm/invoke/generator/txt2img.py

@@ -6,7 +6,7 @@ import torch
from .base import Generator
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Txt2Img(Generator):
@@ -33,7 +33,7 @@ class Txt2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T) -> PIL.Image.Image:

File: ldm/invoke/generator/txt2img2img.py

@@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_
from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Txt2Img2Img(Generator):
@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
conditioning_data = (
ConditioningData(
uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T):
@@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
conditioning_data=conditioning_data,
noise=x_T,
callback=step_callback,
# TODO: threshold = threshold,
)
# Get our initial generation width and height directly from the latent output so

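All three generators above pass `warmup=0.2` into `PostprocessingSettings`. Combined with the `percent_through`-based schedule added in the next file, that means the effective threshold starts out loose and tightens to the configured value over roughly the first 20% of denoising steps. A small sketch of the schedule, using assumed values (threshold=2.0 and 50 steps are examples, not defaults from the code):

```python
threshold, warmup, total_steps = 2.0, 0.2, 50   # assumed example values
for step_index in (0, 4, 9, 10, 25):
    percent_through = step_index / total_steps   # diffusers codepath: simple step ratio
    if percent_through < warmup:
        current = threshold + threshold * 5 * (1 - percent_through / warmup)
    else:
        current = threshold
    print(f"step {step_index:2d}: percent_through={percent_through:.2f}  current_threshold={current:.1f}")
# step  0: percent_through=0.00  current_threshold=12.0
# step  4: percent_through=0.08  current_threshold=8.0
# step  9: percent_through=0.18  current_threshold=3.0
# step 10: percent_through=0.20  current_threshold=2.0
# step 25: percent_through=0.50  current_threshold=2.0
```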
File: ldm/models/diffusion/shared_invokeai_diffusion.py

@@ -15,7 +15,7 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@dataclass(frozen=True)
class ThresholdSettings:
class PostprocessingSettings:
threshold: float
warmup: float
@@ -121,7 +121,6 @@ class InvokeAIDiffuserComponent:
unconditional_guidance_scale: float,
step_index: Optional[int]=None,
total_step_count: Optional[int]=None,
threshold: Optional[ThresholdSettings]=None,
):
"""
:param x: current latents
@@ -130,7 +129,6 @@ class InvokeAIDiffuserComponent:
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
:param threshold: threshold to apply after each step
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
@@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
@@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent:
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
if threshold:
combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
return combined_next_x
def do_latent_postprocessing(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
sigma,
step_index,
total_step_count
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
# methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
@@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
if warmup_scale < 1:
# This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
warming_threshold = 1 + (threshold - 1) * warmup_scale
current_threshold = np.clip(warming_threshold, 1, threshold)
def apply_threshold(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
percent_through
) -> torch.Tensor:
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
if percent_through < warmup:
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
else:
current_threshold = threshold
if current_threshold <= 0:
return latents
maxval = latents.max().item()
minval = latents.min().item()
@@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)]
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
if maxval < current_threshold and minval > -current_threshold:
return latents
num_altered = 0
# MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
if maxval > current_threshold:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
if self.debug_thresholding:
outside = torch.count_nonzero((latents < minval) | (latents > maxval))
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {outside / latents.numel() * 100:.2f}% values will be clamped")
f" | {num_altered / latents.numel() * 100:.2f}% values altered")
return latents.clamp(minval, maxval)
return latents
def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None:
@@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale
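
Putting the pieces together, here is a minimal end-to-end sketch of the new contract between the generators, the pipeline loop, and `do_latent_postprocessing`. Everything prefixed with `Toy` is hypothetical; the simplified threshold body only clamps (the real code replaces out-of-range values with random noise, as shown earlier) and the legacy sigma-based percent-through fallback is omitted.

```python
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass(frozen=True)
class ToyPostprocessingSettings:
    threshold: float
    warmup: float

def toy_do_latent_postprocessing(settings: Optional[ToyPostprocessingSettings],
                                 latents: torch.Tensor,
                                 step_index: int,
                                 total_step_count: int) -> torch.Tensor:
    if settings is None:                       # no threshold requested: the hook is a no-op
        return latents
    percent_through = step_index / total_step_count
    if percent_through < settings.warmup:      # relaxed threshold during warmup (same schedule as above)
        current = settings.threshold + settings.threshold * 5 * (1 - percent_through / settings.warmup)
    else:
        current = settings.threshold
    if current <= 0:
        return latents
    return latents.clamp(-current, current)    # simplified: clamp only, no noise replacement

latents = torch.randn(1, 4, 8, 8) * 4
settings = ToyPostprocessingSettings(threshold=2.0, warmup=0.2)
for i in range(50):                            # stands in for the pipeline's denoising loop
    latents = toy_do_latent_postprocessing(settings, latents, step_index=i, total_step_count=50)
print(latents.abs().max().item())              # bounded by settings.threshold
```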