mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add thresholding for all diffusers types (#2479)
`generator` now asks `InvokeAIDiffuserComponent` to do postprocessing work on latents after every step. Thresholding - now implemented as replacing latents outside of the threshold with random noise - is called at this point. This postprocessing step is also where we can hook up symmetry and other image latent manipulations in the future. Note: code at this layer doesn't need to worry about MPS as relevant torch functions are wrapped and made MPS-safe by `generator.py`.
This commit is contained in:
parent
0bc55a0d55
commit
093174942b
@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
|
||||
from ldm.modules.textual_inversion_manager import TextualInversionManager
|
||||
|
||||
|
||||
@ -199,8 +199,10 @@ class ConditioningData:
|
||||
"""
|
||||
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""Additional arguments to pass to scheduler.step."""
|
||||
threshold: Optional[ThresholdSettings] = None
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
total_step_count=len(timesteps),
|
||||
additional_guidance=additional_guidance)
|
||||
latents = step_output.prev_sample
|
||||
|
||||
latents = self.invokeai_diffuser.do_latent_postprocessing(
|
||||
postprocessing_settings=conditioning_data.postprocessing_settings,
|
||||
latents=latents,
|
||||
sigma=batched_t,
|
||||
step_index=i,
|
||||
total_step_count=len(timesteps)
|
||||
)
|
||||
|
||||
predicted_original = getattr(step_output, 'pred_original_sample', None)
|
||||
|
||||
# TODO resuscitate attention map saving
|
||||
@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
conditioning_data.guidance_scale,
|
||||
step_index=step_index,
|
||||
total_step_count=total_step_count,
|
||||
threshold=conditioning_data.threshold
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
@ -7,7 +7,7 @@ from diffusers import logging
|
||||
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
|
||||
|
||||
class Img2Img(Generator):
|
||||
@ -33,7 +33,7 @@ class Img2Img(Generator):
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
|
||||
from .base import Generator
|
||||
from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
|
||||
from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
@ -33,7 +33,7 @@ class Txt2Img(Generator):
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
|
@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
|
||||
ConditioningData
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
|
||||
conditioning_data = (
|
||||
ConditioningData(
|
||||
uc, c, cfg_scale, extra_conditioning_info,
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
def make_image(x_T):
|
||||
@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
|
||||
conditioning_data=conditioning_data,
|
||||
noise=x_T,
|
||||
callback=step_callback,
|
||||
# TODO: threshold = threshold,
|
||||
)
|
||||
|
||||
# Get our initial generation width and height directly from the latent output so
|
||||
|
@ -15,7 +15,7 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ThresholdSettings:
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
|
||||
@ -121,7 +121,6 @@ class InvokeAIDiffuserComponent:
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: Optional[int]=None,
|
||||
total_step_count: Optional[int]=None,
|
||||
threshold: Optional[ThresholdSettings]=None,
|
||||
):
|
||||
"""
|
||||
:param x: current latents
|
||||
@ -130,7 +129,6 @@ class InvokeAIDiffuserComponent:
|
||||
:param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
|
||||
:param threshold: threshold to apply after each step
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
|
||||
cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
|
||||
|
||||
wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
|
||||
@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
|
||||
|
||||
if threshold:
|
||||
combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
|
||||
|
||||
return combined_next_x
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
sigma,
|
||||
step_index,
|
||||
total_step_count
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
|
||||
latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
f"Either step_index or sigma is required when doing cross attention control, but both are None.")
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
|
||||
@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
return combined_next_x
|
||||
|
||||
def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
|
||||
warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
|
||||
if warmup_scale < 1:
|
||||
# This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
|
||||
warming_threshold = 1 + (threshold - 1) * warmup_scale
|
||||
current_threshold = np.clip(warming_threshold, 1, threshold)
|
||||
def apply_threshold(
|
||||
self,
|
||||
postprocessing_settings: PostprocessingSettings,
|
||||
latents: torch.Tensor,
|
||||
percent_through
|
||||
) -> torch.Tensor:
|
||||
threshold = postprocessing_settings.threshold
|
||||
warmup = postprocessing_settings.warmup
|
||||
|
||||
if percent_through < warmup:
|
||||
current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
|
||||
else:
|
||||
current_threshold = threshold
|
||||
|
||||
if current_threshold <= 0:
|
||||
return latents
|
||||
|
||||
maxval = latents.max().item()
|
||||
minval = latents.min().item()
|
||||
|
||||
@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
|
||||
if self.debug_thresholding:
|
||||
std, mean = [i.item() for i in torch.std_mean(latents)]
|
||||
outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
|
||||
print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
|
||||
f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
|
||||
|
||||
if maxval < current_threshold and minval > -current_threshold:
|
||||
return latents
|
||||
|
||||
num_altered = 0
|
||||
|
||||
# MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
|
||||
|
||||
if maxval > current_threshold:
|
||||
latents = torch.clone(latents)
|
||||
maxval = np.clip(maxval * scale, 1, current_threshold)
|
||||
num_altered += torch.count_nonzero(latents > maxval)
|
||||
latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
|
||||
|
||||
if minval < -current_threshold:
|
||||
latents = torch.clone(latents)
|
||||
minval = np.clip(minval * scale, -current_threshold, -1)
|
||||
num_altered += torch.count_nonzero(latents < minval)
|
||||
latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
|
||||
|
||||
if self.debug_thresholding:
|
||||
outside = torch.count_nonzero((latents < minval) | (latents > maxval))
|
||||
print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
|
||||
f" | {outside / latents.numel() * 100:.2f}% values will be clamped")
|
||||
f" | {num_altered / latents.numel() * 100:.2f}% values altered")
|
||||
|
||||
return latents.clamp(minval, maxval)
|
||||
return latents
|
||||
|
||||
def estimate_percent_through(self, step_index, sigma):
|
||||
if step_index is not None and self.cross_attention_control_context is not None:
|
||||
@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user