Apply black

Martin Kristiansen
2023-07-27 10:54:01 -04:00
parent 2183dba5c5
commit 218b6d0546
148 changed files with 5486 additions and 6296 deletions


@@ -8,6 +8,8 @@ if torch.backends.mps.is_available():
_torch_layer_norm = torch.nn.functional.layer_norm
def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@@ -19,20 +21,26 @@ def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
else:
return _torch_layer_norm(input, normalized_shape, weight, bias, eps)
torch.nn.functional.layer_norm = new_layer_norm
_torch_tensor_permute = torch.Tensor.permute
def new_torch_tensor_permute(input, *dims):
result = _torch_tensor_permute(input, *dims)
if input.device.type == "mps" and input.dtype == torch.float16:
result = result.contiguous()
return result
torch.Tensor.permute = new_torch_tensor_permute
_torch_lerp = torch.lerp
def new_torch_lerp(input, end, weight, *, out=None):
if input.device.type == "mps" and input.dtype == torch.float16:
input = input.float()
@@ -52,20 +60,36 @@ def new_torch_lerp(input, end, weight, *, out=None):
else:
return _torch_lerp(input, end, weight, out=out)
torch.lerp = new_torch_lerp
_torch_interpolate = torch.nn.functional.interpolate
def new_torch_interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False):
def new_torch_interpolate(
input,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
if input.device.type == "mps" and input.dtype == torch.float16:
return _torch_interpolate(input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias).half()
return _torch_interpolate(
input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
).half()
else:
return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
torch.nn.functional.interpolate = new_torch_interpolate
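For reference, every torch patch above follows the same shape: keep a reference to the original op, upcast float16 MPS inputs to float32 before calling it, and cast the result back to half. A minimal standalone sketch of that pattern; the op chosen and the names _orig_gelu and patched_gelu are illustrative only and not part of this commit:

import torch

# Sketch of the upcast-on-MPS workaround used by the patches above (illustrative only).
_orig_gelu = torch.nn.functional.gelu

def patched_gelu(input, *args, **kwargs):
    # Some MPS kernels misbehave in float16, so run the op in float32
    # and cast back so the caller still gets a half-precision result.
    if input.device.type == "mps" and input.dtype == torch.float16:
        return _orig_gelu(input.float(), *args, **kwargs).half()
    return _orig_gelu(input, *args, **kwargs)

torch.nn.functional.gelu = patched_gelu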
# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
class ChunkedSlicedAttnProcessor:
r"""
Processor for implementing sliced attention.
@@ -78,7 +102,7 @@ class ChunkedSlicedAttnProcessor:
def __init__(self, slice_size):
assert isinstance(slice_size, int)
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
slice_size = 1 # TODO: maybe implement chunking in batches too when enough memory
self.slice_size = slice_size
self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)
@@ -121,7 +145,9 @@ class ChunkedSlicedAttnProcessor:
(batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
)
chunk_tmp_tensor = torch.empty(self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device)
chunk_tmp_tensor = torch.empty(
self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
)
for i in range(batch_size_attention // self.slice_size):
start_idx = i * self.slice_size
@@ -131,7 +157,15 @@ class ChunkedSlicedAttnProcessor:
key_slice = key[start_idx:end_idx]
attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
self.get_attention_scores_chunked(attn, query_slice, key_slice, attn_mask_slice, hidden_states[start_idx:end_idx], value[start_idx:end_idx], chunk_tmp_tensor)
self.get_attention_scores_chunked(
attn,
query_slice,
key_slice,
attn_mask_slice,
hidden_states[start_idx:end_idx],
value[start_idx:end_idx],
chunk_tmp_tensor,
)
hidden_states = attn.batch_to_head_dim(hidden_states)
@@ -150,7 +184,6 @@ class ChunkedSlicedAttnProcessor:
return hidden_states
def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
# batch size = 1
assert query.shape[0] == 1
@@ -163,14 +196,14 @@ class ChunkedSlicedAttnProcessor:
query = query.float()
key = key.float()
#out_item_size = query.dtype.itemsize
#if attn.upcast_attention:
# out_item_size = query.dtype.itemsize
# if attn.upcast_attention:
# out_item_size = torch.float32.itemsize
out_item_size = query.element_size()
if attn.upcast_attention:
out_item_size = 4
chunk_size = 2 ** 29
chunk_size = 2**29
out_size = query.shape[1] * key.shape[1] * out_item_size
chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
@@ -181,8 +214,8 @@ class ChunkedSlicedAttnProcessor:
def _get_chunk_view(tensor, start, length):
if start + length > tensor.shape[1]:
length = tensor.shape[1] - start
#print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:,start:start+length]
# print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
return tensor[:, start : start + length]
for chunk_pos in range(0, query.shape[1], chunk_step):
if attention_mask is not None:
@@ -196,7 +229,7 @@ class ChunkedSlicedAttnProcessor:
)
else:
torch.baddbmm(
torch.zeros((1,1,1), device=query.device, dtype=query.dtype),
torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
_get_chunk_view(query, chunk_pos, chunk_step),
key,
beta=0,
@@ -206,7 +239,7 @@ class ChunkedSlicedAttnProcessor:
chunk = chunk.softmax(dim=-1)
torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))
#del chunk
# del chunk
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
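For context on the arithmetic in get_attention_scores_chunked: the processor caps each attention-score buffer at chunk_size = 2**29 bytes (512 MiB) and splits the query dimension into enough chunks to stay under that budget. A rough standalone sketch with made-up shapes; the chunk_step expression is an assumption here, since the line that derives it falls outside the shown hunks:

import math

# Illustrative shapes only: 4096 query tokens x 4096 key tokens, float32 scores (4 bytes each).
query_tokens, key_tokens, item_size = 4096, 4096, 4
chunk_size = 2**29  # per-chunk byte budget for the attention score tensor

out_size = query_tokens * key_tokens * item_size                   # full score tensor, in bytes
chunks_count = min(query_tokens, math.ceil((out_size - 1) / chunk_size))
chunk_step = max(1, query_tokens // chunks_count)                  # assumed: query rows per chunk

print(out_size, chunks_count, chunk_step)  # 67108864 bytes -> 1 chunk covering all 4096 rows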