Mirror of https://github.com/invoke-ai/InvokeAI
Add thresholding for all diffusers types (#2479)
`generator` now asks `InvokeAIDiffuserComponent` to do postprocessing work on the latents after every step. Thresholding, now implemented as replacing latent values outside the threshold with random noise, is applied at this point. This postprocessing step is also where symmetry and other image-latent manipulations can be hooked up in the future. Note that code at this layer does not need to worry about MPS: the relevant torch functions are wrapped and made MPS-safe by `generator.py`.
Parent: 0bc55a0d55
Commit: 093174942b
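The core of the change, as a minimal standalone sketch (simplified from `apply_threshold` in the diff below; the function name and the fixed symmetric band are illustrative):

import torch

def threshold_with_noise(latents: torch.Tensor, threshold: float) -> torch.Tensor:
    # Clone so the caller's tensor is not mutated by the masked assignments below.
    latents = torch.clone(latents)
    over = latents > threshold
    under = latents < -threshold
    # torch.rand_like draws uniform [0, 1) noise; scaling by the bound keeps the
    # replacements inside the band instead of piling them onto the edge.
    latents[over] = torch.rand_like(latents[over]) * threshold
    latents[under] = torch.rand_like(latents[under]) * -threshold
    return latents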
@@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ldm.invoke.globals import Globals
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager
 
 
@@ -199,8 +199,10 @@ class ConditioningData:
     """
     extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """Additional arguments to pass to scheduler.step."""
-    threshold: Optional[ThresholdSettings] = None
+    """
+    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
+    """
+    postprocessing_settings: Optional[PostprocessingSettings] = None
 
     @property
     def dtype(self):
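A note on the bare """ ... """ strings in this hunk: a string literal placed after a dataclass field is a documentation-only convention. Python evaluates and discards it at runtime, so the text exists purely for readers and for doc tools that parse the source. A minimal self-contained sketch (the class name is illustrative):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class Sketch:
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    """Documentation-only: this string is evaluated and thrown away at runtime."""

# The field behaves normally; the string is not attached to it anywhere:
assert Sketch().scheduler_args == {}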
@@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                     total_step_count=len(timesteps),
                                     additional_guidance=additional_guidance)
             latents = step_output.prev_sample
 
+            latents = self.invokeai_diffuser.do_latent_postprocessing(
+                postprocessing_settings=conditioning_data.postprocessing_settings,
+                latents=latents,
+                sigma=batched_t,
+                step_index=i,
+                total_step_count=len(timesteps)
+            )
+
             predicted_original = getattr(step_output, 'pred_original_sample', None)
+
             # TODO resuscitate attention map saving
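Stepping back, the control flow this hunk establishes is: scheduler step first, then latent postprocessing, on every iteration of the denoising loop. A schematic of that loop shape with the repo-specific details stripped out (`denoise`, `step_fn`, and `postprocess_fn` are placeholder names, not part of this commit):

import torch
from typing import Callable, List

def denoise(latents: torch.Tensor,
            timesteps: List[int],
            step_fn: Callable[[torch.Tensor, int], torch.Tensor],
            postprocess_fn: Callable[..., torch.Tensor]) -> torch.Tensor:
    for i, t in enumerate(timesteps):
        latents = step_fn(latents, t)          # scheduler step
        latents = postprocess_fn(latents,      # per-step postprocessing hook
                                 step_index=i,
                                 total_step_count=len(timesteps))
    return latents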
@@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.guidance_scale,
             step_index=step_index,
             total_step_count=total_step_count,
-            threshold=conditioning_data.threshold
         )
 
         # compute the previous noisy sample x_t -> x_t-1
@@ -7,7 +7,7 @@ from diffusers import logging
 
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Img2Img(Generator):
@@ -33,7 +33,7 @@ class Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
 
@@ -6,7 +6,7 @@ import torch
 
 from .base import Generator
 from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Txt2Img(Generator):
@@ -33,7 +33,7 @@ class Txt2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T) -> PIL.Image.Image:
@@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
     ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings
 
 
 class Txt2Img2Img(Generator):
@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T):
@@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
                 conditioning_data=conditioning_data,
                 noise=x_T,
                 callback=step_callback,
-                # TODO: threshold = threshold,
             )
 
             # Get our initial generation width and height directly from the latent output so
@@ -15,7 +15,7 @@ from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 
 
 @dataclass(frozen=True)
-class ThresholdSettings:
+class PostprocessingSettings:
     threshold: float
     warmup: float
 
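For reference, `@dataclass(frozen=True)` makes the settings immutable (and hashable), so a single instance can be shared across every step of a generation without risk of mid-run mutation. A self-contained usage sketch (the threshold value is illustrative; `warmup` is the fraction of total steps over which the threshold ramps down, per `apply_threshold` below):

from dataclasses import dataclass

@dataclass(frozen=True)
class PostprocessingSettings:
    threshold: float
    warmup: float

s = PostprocessingSettings(threshold=2.5, warmup=0.2)
assert s == PostprocessingSettings(2.5, 0.2)   # value equality comes with @dataclass
# s.threshold = 3.0                            # would raise dataclasses.FrozenInstanceError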
@@ -121,7 +121,6 @@ class InvokeAIDiffuserComponent:
                           unconditional_guidance_scale: float,
                           step_index: Optional[int]=None,
                           total_step_count: Optional[int]=None,
-                          threshold: Optional[ThresholdSettings]=None,
                           ):
         """
         :param x: current latents
@@ -130,7 +129,6 @@ class InvokeAIDiffuserComponent:
         :param conditioning: embeddings for conditioned output. for hybrid conditioning this is a dict of tensors [B x 77 x 768], otherwise a single tensor [B x 77 x 768]
         :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
         :param step_index: counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. If None, will be estimated by comparing sigma against self.model.sigmas .
-        :param threshold: threshold to apply after each step
         :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
         """
 
@@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent:
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
-            if step_index is not None and total_step_count is not None:
-                # 🧨diffusers codepath
-                percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
-            else:
-                # legacy compvis codepath
-                # TODO remove when compvis codepath support is dropped
-                if step_index is None and sigma is None:
-                    raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.")
-                percent_through = self.estimate_percent_through(step_index, sigma)
+            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
             cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through)
 
         wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0)
@@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent:
 
         combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale)
 
-        if threshold:
-            combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma)
-
         return combined_next_x
 
+    def do_latent_postprocessing(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        sigma,
+        step_index,
+        total_step_count
+    ) -> torch.Tensor:
+        if postprocessing_settings is not None:
+            percent_through = self.calculate_percent_through(sigma, step_index, total_step_count)
+            latents = self.apply_threshold(postprocessing_settings, latents, percent_through)
+        return latents
+
+    def calculate_percent_through(self, sigma, step_index, total_step_count):
+        if step_index is not None and total_step_count is not None:
+            # 🧨diffusers codepath
+            percent_through = step_index / total_step_count  # will never reach 1.0 - this is deliberate
+        else:
+            # legacy compvis codepath
+            # TODO remove when compvis codepath support is dropped
+            if step_index is None and sigma is None:
+                raise ValueError(
+                    f"Either step_index or sigma is required when doing cross attention control, but both are None.")
+            percent_through = self.estimate_percent_through(step_index, sigma)
+        return percent_through
+
     # methods below are called from do_diffusion_step and should be considered private to this class.
 
     def apply_standard_conditioning(self, x, sigma, unconditioning, conditioning):
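A quick numeric check of the diffusers codepath in `calculate_percent_through` (values illustrative): `step_index / total_step_count` walks [0, 1) and deliberately never reaches 1.0, so comparisons like `percent_through < warmup` remain well defined even on the final step.

total_step_count = 30
percents = [step_index / total_step_count for step_index in range(total_step_count)]
assert percents[0] == 0.0
assert max(percents) < 1.0   # last step is 29/30, roughly 0.967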
@@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
         combined_next_x = unconditioned_next_x + scaled_delta
         return combined_next_x
 
-    def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
-        warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
-        if warmup_scale < 1:
-            # This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
-            warming_threshold = 1 + (threshold - 1) * warmup_scale
-            current_threshold = np.clip(warming_threshold, 1, threshold)
+    def apply_threshold(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        percent_through
+    ) -> torch.Tensor:
+        threshold = postprocessing_settings.threshold
+        warmup = postprocessing_settings.warmup
+
+        if percent_through < warmup:
+            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
         else:
             current_threshold = threshold
 
         if current_threshold <= 0:
             return latents
+
         maxval = latents.max().item()
         minval = latents.min().item()
 
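The warmup branch above loosens the threshold early in sampling and tightens it linearly: at `percent_through == 0` the effective threshold is six times the configured value, and it falls to exactly `threshold` once `percent_through` reaches `warmup`. A small self-contained check (values illustrative):

threshold, warmup = 2.0, 0.2

def current_threshold(percent_through: float) -> float:
    if percent_through < warmup:
        return threshold + threshold * 5 * (1 - (percent_through / warmup))
    return threshold

assert current_threshold(0.0) == 12.0   # 6x the configured threshold at the first step
assert current_threshold(0.1) == 7.0    # halfway through the warmup window
assert current_threshold(0.2) == 2.0    # warmup complete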
@@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
             outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
                   f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
                   f" | {outside / latents.numel() * 100:.2f}% values outside threshold")
 
         if maxval < current_threshold and minval > -current_threshold:
             return latents
 
+        num_altered = 0
+
+        # MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
+
         if maxval > current_threshold:
+            latents = torch.clone(latents)
             maxval = np.clip(maxval * scale, 1, current_threshold)
+            num_altered += torch.count_nonzero(latents > maxval)
+            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval
 
         if minval < -current_threshold:
+            latents = torch.clone(latents)
             minval = np.clip(minval * scale, -current_threshold, -1)
+            num_altered += torch.count_nonzero(latents < minval)
+            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval
 
         if self.debug_thresholding:
-            outside = torch.count_nonzero((latents < minval) | (latents > maxval))
             print(f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
-                  f" | {outside / latents.numel() * 100:.2f}% values will be clamped")
+                  f" | {num_altered / latents.numel() * 100:.2f}% values altered")
 
-        return latents.clamp(minval, maxval)
+        return latents
 
     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
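For contrast with the old behavior (`return latents.clamp(minval, maxval)`): clamping piles every outlier onto the band edge, while the new noise replacement redistributes outliers strictly inside the band. A small self-contained comparison (the bound and distribution are illustrative):

import torch

torch.manual_seed(0)
latents = torch.randn(10_000) * 3
bound = 2.0

clamped = latents.clamp(-bound, bound)

noised = torch.clone(latents)
mask = noised.abs() > bound
noised[mask] = torch.rand_like(noised[mask]) * bound * noised[mask].sign()

print((clamped.abs() == bound).float().mean())   # roughly half the values sit exactly on the edge
print((noised.abs() == bound).float().mean())    # 0.0: replacements land strictly inside the band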
@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
|
|||||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||||
|
|
||||||
return uncond_latents + deltas_merged * global_guidance_scale
|
return uncond_latents + deltas_merged * global_guidance_scale
|
||||||
|
|
||||||
|