diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py
index e5ce403cb7..a3d5ae3c07 100644
--- a/ldm/invoke/generator/diffusers_pipeline.py
+++ b/ldm/invoke/generator/diffusers_pipeline.py
@@ -307,9 +307,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if is_xformers_available() and not Globals.disable_xformers:
             self.enable_xformers_memory_efficient_attention()
         else:
-            slice_size = 2
+            slice_size = 4  # or 2, or 8. i chose this arbitrarily.
             self.enable_attention_slicing(slice_size=slice_size)
 
+
     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
                               *,
diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 8b5467e85f..c248343040 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -560,77 +560,6 @@ class SwapCrossAttnContext:
         return mask, indices
 
 
-class SwapCrossAttnProcessor(CrossAttnProcessor):
-
-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
-                 # kwargs
-                 swap_cross_attn_context: SwapCrossAttnContext=None):
-
-        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
-
-        # if cross-attention control is not in play, just call through to the base implementation.
-        if swap_cross_attn_context is None or not swap_cross_attn_context.wants_cross_attention_control(attention_type):
-            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
-        #else:
-        #    print(f"SwapCrossAttnContext for {attention_type} active")
-
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        # helper function
-        def get_attention_probs(embeddings):
-            this_key = attn.to_k(embeddings)
-            this_key = attn.head_to_batch_dim(this_key)
-            return attn.get_attention_scores(query, this_key, attention_mask)
-
-        if attention_type == CrossAttentionType.SELF:
-            # self attention has no remapping, it just bluntly copies the whole tensor
-            attention_probs = get_attention_probs(hidden_states)
-            value = attn.to_v(hidden_states)
-        else:
-            # tokens (cross) attention
-            # first, find attention probabilities for the "original" prompt
-            original_text_embeddings = encoder_hidden_states
-            original_attention_probs = get_attention_probs(original_text_embeddings)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
-            modified_attention_probs = get_attention_probs(modified_text_embeddings)
-
-            # because the prompt modifications may result in token sequences shifted forwards or backwards,
-            # the original attention probabilities must be remapped to account for token index changes in the
-            # modified prompt
-            remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
-                                                                   swap_cross_attn_context.index_map)
-
-            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
-            mask = swap_cross_attn_context.mask
-            inverse_mask = 1 - mask
-            attention_probs = \
-                remapped_original_attention_probs * mask + \
-                modified_attention_probs * inverse_mask
-
-            # for the "value" just use the modified text embeddings.
-            value = attn.to_v(modified_text_embeddings)
-
-        value = attn.head_to_batch_dim(value)
-
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
-
-
 class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
     # TODO: dynamically pick slice size based on memory conditions
@@ -714,3 +643,9 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         hidden_states = attn.to_out[1](hidden_states)
 
         return hidden_states
+
+
+class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
+
+    def __init__(self):
+        super(SwapCrossAttnProcessor, self).__init__(slice_size=1e6)  # big number so we never slice