diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index 24626247cf..686fb40d3a 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -34,7 +34,7 @@ from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ldm.invoke.globals import Globals
-from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, PostprocessingSettings
 from ldm.modules.textual_inversion_manager import TextualInversionManager


@@ -199,8 +199,10 @@ class ConditioningData:
     """
     extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """Additional arguments to pass to scheduler.step."""
-    threshold: Optional[ThresholdSettings] = None
+    """
+    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
+    """
+    postprocessing_settings: Optional[PostprocessingSettings] = None

     @property
     def dtype(self):
@@ -419,6 +421,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
                                     total_step_count=len(timesteps),
                                     additional_guidance=additional_guidance)
             latents = step_output.prev_sample
+
+            latents = self.invokeai_diffuser.do_latent_postprocessing(
+                postprocessing_settings=conditioning_data.postprocessing_settings,
+                latents=latents,
+                sigma=batched_t,
+                step_index=i,
+                total_step_count=len(timesteps)
+            )
+
             predicted_original = getattr(step_output, 'pred_original_sample', None)

             # TODO resuscitate attention map saving
@@ -455,7 +466,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             conditioning_data.guidance_scale,
             step_index=step_index,
             total_step_count=total_step_count,
-            threshold=conditioning_data.threshold
         )

         # compute the previous noisy sample x_t -> x_t-1
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index bfa50617ef..0b762f7c98 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -7,7 +7,7 @@ from diffusers import logging

 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings


 class Img2Img(Generator):
@@ -33,7 +33,7 @@ class Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

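Note: the diffusers_pipeline.py hunks above replace the per-step `threshold` argument with a post-step hook: `ConditioningData` now carries an optional `PostprocessingSettings`, and `latents_from_embeddings` calls `do_latent_postprocessing` once per step after `scheduler.step` has produced `prev_sample`. Below is a minimal, self-contained sketch of that control flow; the loop, tensor shapes, and the elided thresholding body are illustrative stand-ins, not InvokeAI code.

    from dataclasses import dataclass
    from typing import Optional
    import torch

    @dataclass(frozen=True)
    class PostprocessingSettings:
        # mirrors the renamed dataclass in this patch
        threshold: float
        warmup: float

    def do_latent_postprocessing(settings: Optional[PostprocessingSettings],
                                 latents: torch.Tensor,
                                 step_index: int,
                                 total_step_count: int) -> torch.Tensor:
        # A no-op unless the caller supplied settings, matching the real hook.
        if settings is None:
            return latents
        percent_through = step_index / total_step_count  # never reaches 1.0
        # The real hook applies thresholding here based on percent_through;
        # the body is elided in this sketch (see apply_threshold below).
        return latents

    # Usage, mirroring the loop in latents_from_embeddings:
    latents = torch.randn(1, 4, 8, 8)
    settings = PostprocessingSettings(threshold=1.0, warmup=0.2)
    for i in range(10):
        # ... scheduler.step(...) would produce the next `latents` here ...
        latents = do_latent_postprocessing(settings, latents, i, 10)
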
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 6578794fa7..76da3e4904 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -6,7 +6,7 @@ import torch

 from .base import Generator
 from .diffusers_pipeline import StableDiffusionGeneratorPipeline, ConditioningData
-from ...models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ...models.diffusion.shared_invokeai_diffusion import PostprocessingSettings


 class Txt2Img(Generator):
@@ -33,7 +33,7 @@ class Txt2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T) -> PIL.Image.Image:
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 9632a8d4b0..ff5d3a4d26 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -11,7 +11,7 @@ from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
 from ldm.invoke.generator.base import Generator
 from ldm.invoke.generator.diffusers_pipeline import trim_to_multiple_of, StableDiffusionGeneratorPipeline, \
     ConditioningData
-from ldm.models.diffusion.shared_invokeai_diffusion import ThresholdSettings
+from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings


 class Txt2Img2Img(Generator):
@@ -36,7 +36,7 @@ class Txt2Img2Img(Generator):
         conditioning_data = (
             ConditioningData(
                 uc, c, cfg_scale, extra_conditioning_info,
-                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
+                postprocessing_settings = PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T):
@@ -47,7 +47,6 @@ class Txt2Img2Img(Generator):
                 conditioning_data=conditioning_data,
                 noise=x_T,
                 callback=step_callback,
-                # TODO: threshold = threshold,
             )

             # Get our initial generation width and height directly from the latent output so
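Note: Img2Img, Txt2Img, and Txt2Img2Img all build the settings object the same way, and a falsy `threshold` (0 or None) skips constructing `PostprocessingSettings` entirely, so postprocessing is disabled rather than run with a zero threshold. A sketch of that shared pattern; the helper name `settings_from_threshold` is illustrative, and the import assumes the patched tree is on the path.

    from typing import Optional
    from ldm.models.diffusion.shared_invokeai_diffusion import PostprocessingSettings

    def settings_from_threshold(threshold: float) -> Optional[PostprocessingSettings]:
        # warmup is hard-coded to 20% of the schedule in all three generators
        return PostprocessingSettings(threshold=threshold, warmup=0.2) if threshold else None

    assert settings_from_threshold(0) is None          # thresholding disabled
    assert settings_from_threshold(3.5).warmup == 0.2  # fixed 20% warmup window
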
""" @@ -138,15 +136,7 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = [] context: Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: - if step_index is not None and total_step_count is not None: - # 🧨diffusers codepath - percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate - else: - # legacy compvis codepath - # TODO remove when compvis codepath support is dropped - if step_index is None and sigma is None: - raise ValueError(f"Either step_index or sigma is required when doing cross attention control, but both are None.") - percent_through = self.estimate_percent_through(step_index, sigma) + percent_through = self.calculate_percent_through(sigma, step_index, total_step_count) cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) @@ -161,11 +151,34 @@ class InvokeAIDiffuserComponent: combined_next_x = self._combine(unconditioned_next_x, conditioned_next_x, unconditional_guidance_scale) - if threshold: - combined_next_x = self._threshold(threshold.threshold, threshold.warmup, combined_next_x, sigma) - return combined_next_x + def do_latent_postprocessing( + self, + postprocessing_settings: PostprocessingSettings, + latents: torch.Tensor, + sigma, + step_index, + total_step_count + ) -> torch.Tensor: + if postprocessing_settings is not None: + percent_through = self.calculate_percent_through(sigma, step_index, total_step_count) + latents = self.apply_threshold(postprocessing_settings, latents, percent_through) + return latents + + def calculate_percent_through(self, sigma, step_index, total_step_count): + if step_index is not None and total_step_count is not None: + # 🧨diffusers codepath + percent_through = step_index / total_step_count # will never reach 1.0 - this is deliberate + else: + # legacy compvis codepath + # TODO remove when compvis codepath support is dropped + if step_index is None and sigma is None: + raise ValueError( + f"Either step_index or sigma is required when doing cross attention control, but both are None.") + percent_through = self.estimate_percent_through(step_index, sigma) + return percent_through + # methods below are called from do_diffusion_step and should be considered private to this class. 
@@ -275,17 +288,23 @@ class InvokeAIDiffuserComponent:
         combined_next_x = unconditioned_next_x + scaled_delta
         return combined_next_x

-    def _threshold(self, threshold, warmup, latents: torch.Tensor, sigma) -> torch.Tensor:
-        warmup_scale = (1 - sigma.item() / 1000) / warmup if warmup else math.inf
-        if warmup_scale < 1:
-            # This arithmetic based on https://github.com/invoke-ai/InvokeAI/pull/395
-            warming_threshold = 1 + (threshold - 1) * warmup_scale
-            current_threshold = np.clip(warming_threshold, 1, threshold)
+    def apply_threshold(
+        self,
+        postprocessing_settings: PostprocessingSettings,
+        latents: torch.Tensor,
+        percent_through
+    ) -> torch.Tensor:
+        threshold = postprocessing_settings.threshold
+        warmup = postprocessing_settings.warmup
+
+        if percent_through < warmup:
+            current_threshold = threshold + threshold * 5 * (1 - (percent_through / warmup))
         else:
             current_threshold = threshold

         if current_threshold <= 0:
             return latents
+
         maxval = latents.max().item()
         minval = latents.min().item()

@@ -294,25 +313,34 @@ class InvokeAIDiffuserComponent:
         if self.debug_thresholding:
             std, mean = [i.item() for i in torch.std_mean(latents)]
             outside = torch.count_nonzero((latents < -current_threshold) | (latents > current_threshold))
-            print(f"\nThreshold: 𝜎={sigma.item()} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
+            print(f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
                   f"   | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
                   f"   | {outside / latents.numel() * 100:.2f}% values outside threshold")

         if maxval < current_threshold and minval > -current_threshold:
             return latents

+        num_altered = 0
+
+        # MPS torch.rand_like is fine because torch.rand_like is wrapped in generate.py!
+
         if maxval > current_threshold:
+            latents = torch.clone(latents)
             maxval = np.clip(maxval * scale, 1, current_threshold)
+            num_altered += torch.count_nonzero(latents > maxval)
+            latents[latents > maxval] = torch.rand_like(latents[latents > maxval]) * maxval

         if minval < -current_threshold:
+            latents = torch.clone(latents)
             minval = np.clip(minval * scale, -current_threshold, -1)
+            num_altered += torch.count_nonzero(latents < minval)
+            latents[latents < minval] = torch.rand_like(latents[latents < minval]) * minval

         if self.debug_thresholding:
-            outside = torch.count_nonzero((latents < minval) | (latents > maxval))
             print(f"   | min,     , max = {minval:.3f},      , {maxval:.3f}\t(scaled by {scale})\n"
-                  f"   | {outside / latents.numel() * 100:.2f}% values will be clamped")
+                  f"   | {num_altered / latents.numel() * 100:.2f}% values altered")

-        return latents.clamp(minval, maxval)
+        return latents

     def estimate_percent_through(self, step_index, sigma):
         if step_index is not None and self.cross_attention_control_context is not None:
@@ -376,4 +404,3 @@ class InvokeAIDiffuserComponent:
         #       assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))

         return uncond_latents + deltas_merged * global_guidance_scale
-
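Note: the replacement warmup schedule in `apply_threshold` is driven by `percent_through` rather than sigma: the effective threshold starts at six times the configured value on the first step and relaxes linearly to the configured value at `percent_through == warmup`, after which it stays constant. Out-of-range values are then re-drawn as random values scaled into range instead of being hard-clamped, which is why the final `latents.clamp(minval, maxval)` is gone. A self-contained check of the schedule arithmetic, using the same expression as the hunk above:

    def current_threshold(threshold: float, warmup: float, percent_through: float) -> float:
        # identical arithmetic to apply_threshold() in the patch
        if percent_through < warmup:
            return threshold + threshold * 5 * (1 - (percent_through / warmup))
        return threshold

    assert current_threshold(1.0, 0.2, 0.0) == 6.0   # 6x threshold at the first step
    assert current_threshold(1.0, 0.2, 0.1) == 3.5   # halfway through warmup
    assert current_threshold(1.0, 0.2, 0.2) == 1.0   # warmup complete
    assert current_threshold(1.0, 0.2, 0.9) == 1.0   # stays at the configured value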