diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py
index d5c3eaadf0..71d5995b4a 100644
--- a/ldm/models/diffusion/cross_attention.py
+++ b/ldm/models/diffusion/cross_attention.py
@@ -74,6 +74,8 @@ class CrossAttentionControl:
         self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
         tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
         for m in self_attention_modules+tokens_attention_modules:
+            # clear out the saved slice in case the outermost dim changes
+            m.last_attn_slice = None
             m.save_last_attn_slice = True
 
     @classmethod
@@ -91,6 +93,8 @@ class CrossAttentionControl:
 
         def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
 
+            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
+
             attn_slice = suggested_attention_slice
             if dim is not None:
                 start = offset
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index c8b4823111..7459e2e7cc 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -50,23 +50,32 @@ class CFGDenoiser(nn.Module):
 
         CrossAttentionControl.clear_requests(self.inner_model)
 
-        #rint('generating unconditioned latents')
-        unconditioned_latents = self.inner_model(x, sigma, cond=uncond)
+        if self.edited_conditioning is None:
+            # faster batch path
+            x_twice = torch.cat([x]*2)
+            sigma_twice = torch.cat([sigma]*2)
+            both_conditionings = torch.cat([uncond, cond])
+            unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2)
+        else:
+            # slower non-batched path (20% slower on mac MPS)
+            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
+            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
+            # This messes up their application later, due to mismatched shape of dim 0 (16 vs. 8)
+            # (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
+            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
+            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
+            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
+            unconditioned_next_x = self.inner_model(x, sigma, cond=uncond)
 
-        # process x using the original prompt, saving the attention maps if required
-        if self.edited_conditioning is not None:
-            # this is automatically toggled off after the model forward()
+            # process x using the original prompt, saving the attention maps
             CrossAttentionControl.request_save_attention_maps(self.inner_model)
-        #print('generating conditioned latents')
-        conditioned_latents = self.inner_model(x, sigma, cond=cond)
-
-        if self.edited_conditioning is not None:
-            # process x again, using the saved attention maps but the new conditioning
-            # this is automatically toggled off after the model forward()
+            _ = self.inner_model(x, sigma, cond=cond)
             CrossAttentionControl.clear_requests(self.inner_model)
+
+            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
             CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model)
-            #print('generating edited conditioned latents')
-            conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning)
+            conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning)
+            CrossAttentionControl.clear_requests(self.inner_model)
 
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
@@ -75,8 +84,9 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        delta = (conditioned_latents - unconditioned_latents)
-        return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh)
+        # to scale how much effect conditioning has, calculate the delta it produces and scale that
+        scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
+        return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh)
 
 
 class KSampler(Sampler):
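
Note (illustration, not part of the diff): the "faster batch path" above works because the k-diffusion denoiser treats the batch dimension independently, so stacking the latents, sigmas, and conditionings lets a single forward pass produce both the unconditioned and conditioned predictions, which `.chunk(2)` then splits apart. A minimal sketch of that idea, with a hypothetical `denoiser` callable standing in for `self.inner_model`:

```python
import torch

def batched_cfg_predictions(denoiser, x, sigma, uncond, cond):
    # Stack everything along dim 0 so one forward pass covers both conditionings.
    x_twice = torch.cat([x] * 2)            # (2B, C, H, W)
    sigma_twice = torch.cat([sigma] * 2)    # (2B,)
    both_conditionings = torch.cat([uncond, cond])
    out = denoiser(x_twice, sigma_twice, cond=both_conditionings)
    # First half of the batch is the unconditioned prediction, second half the conditioned one.
    unconditioned_next_x, conditioned_next_x = out.chunk(2)
    return unconditioned_next_x, conditioned_next_x
```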
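
Note (illustration, not part of the diff): the final combination step is standard classifier-free guidance with the delta pre-scaled. The sketch below shows the arithmetic; the clamp is a simplified stand-in for the repo's `cfg_apply_threshold`, whose exact behaviour is not shown in this diff:

```python
import torch

def combine_cfg(unconditioned_next_x, conditioned_next_x, cond_scale, thresh=None):
    # Scale how much effect the conditioning has by scaling the difference
    # between the conditioned and unconditioned predictions.
    scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale
    result = unconditioned_next_x + scaled_delta
    if thresh is not None:
        # Simplified stand-in for cfg_apply_threshold: clamp extreme latent values.
        result = result.clamp(-thresh, thresh)
    return result
```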