From 95d147c5df4a4e70699fb8debd430f0955bb99b0 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Wed, 25 Jan 2023 23:03:30 +0100
Subject: [PATCH] MPS support: negatory

---
 .../diffusion/cross_attention_control.py | 73 +------------------------
 1 file changed, 1 insertion(+), 72 deletions(-)

diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index a1b680c411..9712ddf1bd 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -654,76 +654,5 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
 
 
 class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
     def __init__(self):
-        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))
-
-    # theoretically this class could simply inherit from SlicedSwapCrossAttnProcesser
-    # and consist wholly of an __init__ method that just calls super().__init__(slice_size=1000000000)
-    # - such a giant slice size would resolve to 'no slicing' at runtime.
-    # however, pytorch MPS is borked until https://github.com/kulinseth/pytorch/pull/222 is merged into
-    # mainline pytorch. so for now this has to be a full implementation.
-
-    def no__call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
-                   # kwargs
-                   swap_cross_attn_context: SwapCrossAttnContext=None):
-
-        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
-
-        # if cross-attention control is not in play, just call through to the base implementation.
-        if attention_type == CrossAttentionType.SELF or \
-                swap_cross_attn_context is None or \
-                not swap_cross_attn_context.wants_cross_attention_control(attention_type):
-            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
-        #else:
-        #    print(f"SwapCrossAttnContext for {attention_type} active")
-
-        batch_size, sequence_length, _ = hidden_states.shape
-        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
-
-        query = attn.to_q(hidden_states)
-        query = attn.head_to_batch_dim(query)
-
-        # helper function
-        def get_attention_probs(embeddings):
-            this_key = attn.to_k(embeddings)
-            this_key = attn.head_to_batch_dim(this_key)
-            return attn.get_attention_scores(query, this_key, attention_mask)
-
-        # tokens (cross) attention
-        # first, find attention probabilities for the "original" prompt
-        original_text_embeddings = encoder_hidden_states
-        original_attention_probs = get_attention_probs(original_text_embeddings)
-
-        # then, find attention probabilities for the "modified" prompt
-        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
-        modified_attention_probs = get_attention_probs(modified_text_embeddings)
-
-        # because the prompt modifications may result in token sequences shifted forwards or backwards,
-        # the original attention probabilities must be remapped to account for token index changes in the
-        # modified prompt
-        remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
-                                                               swap_cross_attn_context.index_map).clone()
-
-        # only some tokens taken from the original attention probabilities. this is controlled by the mask.
-        mask = swap_cross_attn_context.mask
-        inverse_mask = 1 - mask.clone()
-        attention_probs = \
-            remapped_original_attention_probs * mask + \
-            modified_attention_probs * inverse_mask
-
-        # for the "value" just use the modified text embeddings.
-        value = attn.to_v(modified_text_embeddings)
-
-        value = attn.head_to_batch_dim(value)
-
-        hidden_states = torch.bmm(attention_probs, value)
-        hidden_states = attn.batch_to_head_dim(hidden_states)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        return hidden_states
-
+        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice
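
Note (not the InvokeAI/diffusers implementation): a minimal, self-contained sketch of the reasoning behind the removed comment and the new inline comment. When a sliced attention processor's slice size exceeds the batch-of-heads dimension, the slicing loop collapses to a single pass over the full tensors, so slice_size=int(1e9) is effectively "no slicing". The function name sliced_attention and the tensor shapes below are illustrative assumptions, not this repo's API.

import math
import torch

def sliced_attention(query, key, value, slice_size):
    # query/key/value: (batch * heads, tokens, dim); slicing is over the first dim.
    batch_heads = query.shape[0]
    num_slices = math.ceil(batch_heads / slice_size)  # 1 when slice_size >= batch_heads
    scale = query.shape[-1] ** -0.5
    out = torch.empty(batch_heads, query.shape[1], value.shape[2],
                      dtype=query.dtype, device=query.device)
    for i in range(num_slices):
        start, end = i * slice_size, min((i + 1) * slice_size, batch_heads)
        # plain scaled dot-product attention on this slice only
        scores = torch.bmm(query[start:end], key[start:end].transpose(-1, -2)) * scale
        out[start:end] = torch.bmm(scores.softmax(dim=-1), value[start:end])
    return out

# With slice_size=int(1e9) the loop body runs exactly once over the whole batch,
# and the result matches genuinely sliced attention.
q = k = v = torch.randn(8, 77, 64)
assert torch.allclose(sliced_attention(q, k, v, slice_size=int(1e9)),
                      sliced_attention(q, k, v, slice_size=2), atol=1e-6)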