Add thresholding for all diffusers types (#2479)

`generator` now asks `InvokeAIDiffuserComponent` to do postprocessing work on latents after every step. Thresholding - now implemented as replacing latents outside of the threshold with random noise - is called at this point. This postprocessing step is also where we can hook up symmetry and other image latent manipulations in the future.

Note: code at this layer doesn't need to worry about MPS as relevant torch functions are wrapped and made MPS-safe by `generator.py`.
This commit is contained in:
Jonathan 2023-02-14 18:00:34 -06:00 committed by GitHub
parent 0bc55a0d55
commit 093174942b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 73 additions and 37 deletions

View File

@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager from ldm.modules.textual_inversion_manager import TextualInversionManager
@ -199,8 +199,10 @@ class ConditioningData:
""" """
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict) scheduler_args: dict[str, Any] = field(default_factory=dict)
"""Additional arguments to pass to scheduler.step.""" """
threshold: Optional[ThresholdSettings] = None Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
@property @property
def dtype(self): def dtype(self):
@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=len(timesteps), total_step_count=len(timesteps),
additional_guidance=additional_guidance) additional_guidance=additional_guidance)
latents = step_output.prev_sample latents = step_output.prev_sample
latents = self.invokeai_diffuser.do_latent_postprocessing(
postprocessing_settings=conditioning_data.postprocessing_settings,
latents=latents,
sigma=batched_t,
step_index=i,
total_step_count=len(timesteps)
)
predicted_original = getattr(step_output, 'pred_original_sample', None) predicted_original = getattr(step_output, 'pred_original_sample', None)
# TODO resuscitate attention map saving # TODO resuscitate attention map saving
@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
conditioning_data.guidance_scale, conditioning_data.guidance_scale,
step_index=step_index, step_index=step_index,
total_step_count=total_step_count, total_step_count=total_step_count,
threshold=conditioning_data.threshold
) )
# compute the previous noisy sample x_t -> x_t-1 # compute the previous noisy sample x_t -> x_t-1

View File

@ -7,7 +7,7 @@ from diffusers import logging
from ldm.invoke.generator.base import Generator from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Img2Img(Generator): class Img2Img(Generator):
@ -33,7 +33,7 @@ class Img2Img(Generator):
conditioning_data = ( conditioning_data = (
ConditioningData( ConditioningData(
uc, c, cfg_scale, extra_conditioning_info, uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None) postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

View File

@ -6,7 +6,7 @@ import torch
from .base import Generator from .base import Generator
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Txt2Img(Generator): class Txt2Img(Generator):
@ -33,7 +33,7 @@ class Txt2Img(Generator):
conditioning_data = ( conditioning_data = (
ConditioningData( ConditioningData(
uc, c, cfg_scale, extra_conditioning_info, uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None) postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T) -> PIL.Image.Image: def make_image(x_T) -> PIL.Image.Image:

View File

@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_
from ldm.invoke.generator.base import Generator from ldm.invoke.generator.base import Generator
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \ from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
ConditioningData ConditioningData
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
class Txt2Img2Img(Generator): class Txt2Img2Img(Generator):
@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
conditioning_data = ( conditioning_data = (
ConditioningData( ConditioningData(
uc, c, cfg_scale, extra_conditioning_info, uc, c, cfg_scale, extra_conditioning_info,
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None) postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)) .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
def make_image(x_T): def make_image(x_T):
@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
noise=x_T, noise=x_T,
callback=step_callback, callback=step_callback,
# TODO: threshold = threshold,
) )
# Get our initial generation width and height directly from the latent output so # Get our initial generation width and height directly from the latent output so

View File

@ -15,7 +15,7 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
@dataclass(frozen=True) @dataclass(frozen=True)
class ThresholdSettings: class PostprocessingSettings:
threshold: float threshold: float
warmup: float warmup: float
@ -121,7 +121,6 @@ class InvokeAIDiffuserComponent:
unconditional_guidance_scale: float, unconditional_guidance_scale: float,
step_index: Optional[int]=None, step_index: Optional[int]=None,
total_step_count: Optional[int]=None, total_step_count: Optional[int]=None,
threshold: Optional[ThresholdSettings]=None,
): ):
""" """
:param x: current latents :param x: current latents
@ -130,7 +129,6 @@ class InvokeAIDiffuserComponent:
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768] :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas . :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
:param threshold: threshold to apply after each step
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
""" """
@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent:
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:
if step_index is not None and total_step_count is not None: percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent:
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale) combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
if threshold:
combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
return combined_next_x return combined_next_x
def do_latent_postprocessing(
self,
postprocessing_settings: PostprocessingSettings,
latents: torch.Tensor,
sigma,
step_index,
total_step_count
) -> torch.Tensor:
if postprocessing_settings is not None:
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
return latents
def calculate_percent_through(self, sigma, step_index, total_step_count):
if step_index is not None and total_step_count is not None:
# 🧨diffusers codepath
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
else:
# legacy compvis codepath
# TODO remove when compvis codepath support is dropped
if step_index is None and sigma is None:
raise ValueError(
f"Either step_index or sigma is required when doing cross attention control, but both are None.")
percent_through = self.estimate_percent_through(step_index, sigma)
return percent_through
# methods below are called from do_diffusion_step and should be considered private to this class. # methods below are called from do_diffusion_step and should be considered private to this class.
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning): def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
combined_next_x = unconditioned_next_x + scaled_delta combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x return combined_next_x
def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor: def apply_threshold(
warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf self,
if warmup_scale < 1: postprocessing_settings: PostprocessingSettings,
# This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395 latents: torch.Tensor,
warming_threshold = 1 + (threshold - 1) * warmup_scale percent_through
current_threshold = np.clip(warming_threshold, 1, threshold) ) -> torch.Tensor:
threshold = postprocessing_settings.threshold
warmup = postprocessing_settings.warmup
if percent_through < warmup:
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
else: else:
current_threshold = threshold current_threshold = threshold
if current_threshold <= 0: if current_threshold <= 0:
return latents return latents
maxval = latents.max().item() maxval = latents.max().item()
minval = latents.min().item() minval = latents.min().item()
@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
if self.debug_thresholding: if self.debug_thresholding:
std, mean = [i.item() for i in torch.std_mean(latents)] std, mean = [i.item() for i in torch.std_mean(latents)]
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold)) outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n" print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n" f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
f" | {outside / latents.numel() * 100:.2f}% values outside threshold") f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
if maxval < current_threshold and minval > -current_threshold: if maxval < current_threshold and minval > -current_threshold:
return latents return latents
num_altered = 0
# MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
if maxval > current_threshold: if maxval > current_threshold:
latents = torch.clone(latents)
maxval = np.clip(maxval * scale, 1, current_threshold) maxval = np.clip(maxval * scale, 1, current_threshold)
num_altered += torch.count_nonzero(latents > maxval)
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
if minval < -current_threshold: if minval < -current_threshold:
latents = torch.clone(latents)
minval = np.clip(minval * scale, -current_threshold, -1) minval = np.clip(minval * scale, -current_threshold, -1)
num_altered += torch.count_nonzero(latents < minval)
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
if self.debug_thresholding: if self.debug_thresholding:
outside = torch.count_nonzero((latents < minval) | (latents > maxval))
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n" print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
f" | {outside / latents.numel() * 100:.2f}% values will be clamped") f" | {num_altered / latents.numel() * 100:.2f}% values altered")
return latents.clamp(minval, maxval) return latents
def estimate_percent_through(self, step_index, sigma): def estimate_percent_through(self, step_index, sigma):
if step_index is not None and self.cross_attention_control_context is not None: if step_index is not None and self.cross_attention_control_context is not None:
@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale return uncond_latents + deltas_merged * global_guidance_scale