wip tracking down MPS slicing support

This commit is contained in:
Damian Stewart 2023-01-25 22:27:23 +01:00
parent 34a3f4a820
commit 41aed57449
2 changed files with 92 additions and 8 deletions

View File

@ -307,8 +307,14 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if is_xformers_available() and not Globals.disable_xformers:
self.enable_xformers_memory_efficient_attention()
else:
slice_size = 4 # or 2, or 8. i chose this arbitrarily.
self.enable_attention_slicing(slice_size=slice_size)
if torch.backends.mps.is_available():
# until pytorch #91617 is fixed, slicing is borked on MPS
# https://github.com/pytorch/pytorch/issues/91617
# fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
pass
else:
slice_size = 4 # or 2, or 8. i chose this arbitrarily.
self.enable_attention_slicing(slice_size=slice_size)
def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,

View File

@ -344,10 +344,14 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
if is_running_diffusers:
unet = model
old_attn_processors = unet.attn_processors
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
if torch.backends.mps.is_available():
# see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
unet.set_attn_processor(SwapCrossAttnProcessor())
else:
# try to re-use an existing slice size
default_slice_size = 4
slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
return old_attn_processors
else:
context.register_cross_attention_modules(model)
@ -605,7 +609,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
)
# do slices
for i in range(hidden_states.shape[0] // self.slice_size):
for i in range(max(1,hidden_states.shape[0] // self.slice_size)):
start_idx = i * self.slice_size
end_idx = (i + 1) * self.slice_size
@ -630,6 +634,8 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
remapped_original_attn_slice * mask + \
modified_attn_slice * inverse_mask
del remapped_original_attn_slice, modified_attn_slice
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
@ -648,4 +654,76 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
def __init__(self):
super(SwapCrossAttnProcessor, self).__init__(slice_size=1e6) # big number so we never slice
super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9))
# theoretically this class could simply inherit from SlicedSwapCrossAttnProcesser
# and consist wholly of an __init__ method that just calls super().__init__(slice_size=1000000000)
# - such a giant slice size would resolve to 'no slicing' at runtime.
# however, pytorch MPS is borked until https://github.com/kulinseth/pytorch/pull/222 is merged into
# mainline pytorch. so for now this has to be a full implementation.
def no__call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext=None):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
# if cross-attention control is not in play, just call through to the base implementation.
if attention_type == CrossAttentionType.SELF or \
swap_cross_attn_context is None or \
not swap_cross_attn_context.wants_cross_attention_control(attention_type):
#print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
#else:
# print(f"SwapCrossAttnContext for {attention_type} active")
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
query = attn.head_to_batch_dim(query)
# helper function
def get_attention_probs(embeddings):
this_key = attn.to_k(embeddings)
this_key = attn.head_to_batch_dim(this_key)
return attn.get_attention_scores(query, this_key, attention_mask)
# tokens (cross) attention
# first, find attention probabilities for the "original" prompt
original_text_embeddings = encoder_hidden_states
original_attention_probs = get_attention_probs(original_text_embeddings)
# then, find attention probabilities for the "modified" prompt
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
modified_attention_probs = get_attention_probs(modified_text_embeddings)
# because the prompt modifications may result in token sequences shifted forwards or backwards,
# the original attention probabilities must be remapped to account for token index changes in the
# modified prompt
remapped_original_attention_probs = torch.index_select(original_attention_probs, -1,
swap_cross_attn_context.index_map).clone()
# only some tokens taken from the original attention probabilities. this is controlled by the mask.
mask = swap_cross_attn_context.mask
inverse_mask = 1 - mask.clone()
attention_probs = \
remapped_original_attention_probs * mask + \
modified_attention_probs * inverse_mask
# for the "value" just use the modified text embeddings.
value = attn.to_v(modified_text_embeddings)
value = attn.head_to_batch_dim(value)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states