fix: mps attention fix for sd2

This commit is contained in:
psychedelicious 2023-07-21 21:52:12 +10:00 committed by Kent Keirsey
parent 055b2207cb
commit 3f79812dc6

View File

@ -83,7 +83,7 @@ class ChunkedSlicedAttnProcessor:
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
if self.slice_size != 1:
if self.slice_size != 1 or attn.upcast_attention:
return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)
residual = hidden_states