2023-07-21 01:08:49 +00:00
|
|
|
import math
|
2023-07-04 22:05:01 +00:00
|
|
|
import torch
|
2023-07-21 01:08:49 +00:00
|
|
|
import diffusers
|
2023-07-04 22:05:01 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Monkey-patch: whenever the MPS backend is available, make `torch.empty` return
# zero-filled tensors (it is rebound to `torch.zeros`) instead of uninitialised memory.
# This rebinding is process-wide and applies to every device argument, not only "mps".
# NOTE(review): the motivating MPS problem is not recorded here — presumably
# uninitialised buffers were leaking garbage/NaNs into results; confirm before removing.
if torch.backends.mps.is_available():
    torch.empty = torch.zeros
|
|
|
|
|
|
|
|
|
|
|
|
# Original implementation, captured before it is replaced at the bottom of this block.
_torch_layer_norm = torch.nn.functional.layer_norm


def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    """Drop-in replacement for ``torch.nn.functional.layer_norm``.

    For float16 tensors on an MPS device, the normalization is computed in float32
    (the input, plus ``weight``/``bias`` when they are given) and the output is cast
    back with ``.half()``. Every other input is forwarded unchanged to the original
    implementation.
    """
    on_mps_fp16 = input.device.type == "mps" and input.dtype == torch.float16
    if not on_mps_fp16:
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)

    weight_fp32 = None if weight is None else weight.float()
    bias_fp32 = None if bias is None else bias.float()
    normalized = _torch_layer_norm(input.float(), normalized_shape, weight_fp32, bias_fp32, eps)
    return normalized.half()


# Install the wrapper: later lookups of torch.nn.functional.layer_norm resolve to it.
torch.nn.functional.layer_norm = new_layer_norm
|
|
|
|
|
|
|
|
|
|
|
|
# Original Tensor.permute, captured before it is replaced at the bottom of this block.
_torch_tensor_permute = torch.Tensor.permute


def new_torch_tensor_permute(input, *dims):
    """Drop-in replacement for ``torch.Tensor.permute``.

    The permutation itself is delegated to the original method, so both
    ``t.permute(1, 0)`` and ``t.permute((1, 0))`` keep working. For float16 tensors
    on an MPS device, the permuted view is additionally materialized with
    ``.contiguous()``.

    NOTE(review): the MPS problem this works around is not recorded here —
    presumably some MPS fp16 kernels mishandled non-contiguous permuted views.
    """
    result = _torch_tensor_permute(input, *dims)
    # Compare the device *type*: a torch.device never compares equal to the str "mps",
    # so the previous `input.device == "mps"` check could never trigger this workaround.
    if input.device.type == "mps" and input.dtype == torch.float16:
        result = result.contiguous()
    return result


# Install the wrapper on the Tensor class so every `tensor.permute(...)` goes through it.
torch.Tensor.permute = new_torch_tensor_permute
|
|
|
|
|
|
|
|
|
|
|
|
# Original torch.lerp, captured before it is replaced at the bottom of this block.
_torch_lerp = torch.lerp


def new_torch_lerp(input, end, weight, *, out=None):
    """Drop-in replacement for ``torch.lerp``.

    For float16 ``input`` on an MPS device, ``input``, ``end`` and (when it is a tensor)
    ``weight`` are upcast to float32, the interpolation runs in float32, and the result
    is cast back with ``.half()``. Every other input is forwarded unchanged.

    When ``out`` is given, the result is written into it and ``out`` itself is returned,
    matching the ``out=`` contract of ``torch.lerp`` (and of the non-MPS branch below).
    """
    if not (input.device.type == "mps" and input.dtype == torch.float16):
        return _torch_lerp(input, end, weight, out=out)

    input = input.float()
    end = end.float()
    if isinstance(weight, torch.Tensor):
        weight = weight.float()

    if out is None:
        return _torch_lerp(input, end, weight).half()

    # Interpolate into a float32 staging tensor shaped like `out`, then copy the
    # downcast values into the caller's buffer.
    out_fp32 = torch.zeros_like(out, dtype=torch.float32)
    _torch_lerp(input, end, weight, out=out_fp32)
    out.copy_(out_fp32.half())
    del out_fp32
    # Previously a fresh `.half()` tensor was returned here instead of `out`.
    return out


# Install the wrapper: later lookups of torch.lerp resolve to it.
torch.lerp = new_torch_lerp
|
|
|
|
|
|
|
|
|
|
|
|
# Original implementation, captured before it is replaced at the bottom of this block.
_torch_interpolate = torch.nn.functional.interpolate


def new_torch_interpolate(
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
    antialias=False,
):
    """Drop-in replacement for ``torch.nn.functional.interpolate``.

    float16 tensors on an MPS device are resampled in float32 and cast back with
    ``.half()``. All other inputs go straight to the original function. Every
    resampling option is forwarded positionally, in the original signature's order.
    """
    options = (size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)
    if input.device.type != "mps" or input.dtype != torch.float16:
        return _torch_interpolate(input, *options)
    return _torch_interpolate(input.float(), *options).half()


# Install the wrapper: later lookups of torch.nn.functional.interpolate resolve to it.
torch.nn.functional.interpolate = new_torch_interpolate
|
2023-07-21 01:08:49 +00:00
|
|
|
|
|
|
|
# TODO: refactor it
# Keep a handle on diffusers' stock SlicedAttnProcessor before the module attribute is
# overwritten at the end of this file; ChunkedSlicedAttnProcessor instantiates it and
# delegates to it for the cases its own chunked path does not handle.
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor
|
2023-07-28 13:46:44 +00:00
|
|
|
|
|
|
|
|
2023-07-21 01:08:49 +00:00
|
|
|
class ChunkedSlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Attention is computed one ``batch * heads`` slice at a time (the slice size is forced to 1).
    Inside each slice, the query rows are processed in chunks so that a single block of attention
    scores stays under a byte budget (see ``get_attention_scores_chunked``).

    When ``attn.upcast_attention`` is set, the whole call is delegated to diffusers' original
    ``SlicedAttnProcessor``, which is constructed with a slice size of 1.

    Args:
        slice_size (`int`):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`. The value must be an ``int`` but is
            currently ignored: this processor always uses a slice size of 1.
    """

    def __init__(self, slice_size):
        assert isinstance(slice_size, int)
        slice_size = 1  # TODO: maybe implement chunking in batches too when enough memory

        self.slice_size = slice_size
        # Fallback used by __call__ for the cases the chunked path below does not handle.
        self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        """Run attention for ``attn`` over ``hidden_states`` and return the new hidden states.

        Steps:
        - Reshape 4-D input ``(batch, channel, height, width)`` to ``(batch, height*width, channel)``.
        - Apply ``attn.group_norm`` if present.
        - Project q/k/v and split heads into the batch dimension.
        - Compute attention one slice at a time through ``get_attention_scores_chunked``.
        - Merge heads, then apply the output projection and dropout.
        - Restore the 4-D shape if needed, add the residual if ``attn.residual_connection`` is set,
          and divide by ``attn.rescale_output_factor``.
        """
        if self.slice_size != 1 or attn.upcast_attention:
            return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        # NOTE(review): presumably returns a (batch*heads, 1 or query_len, key_len) additive mask,
        # or None — confirm against the installed diffusers version.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Scratch buffer for one slice's attention scores, reused across all slices.
        chunk_tmp_tensor = torch.empty(
            self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )

        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            # Writes this slice's attention output into hidden_states[start_idx:end_idx] in place.
            self.get_attention_scores_chunked(
                attn,
                query_slice,
                key_slice,
                attn_mask_slice,
                hidden_states[start_idx:end_idx],
                value[start_idx:end_idx],
                chunk_tmp_tensor,
            )

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def get_attention_scores_chunked(
        self, attn, query, key, attention_mask, hidden_states, value, chunk, max_chunk_bytes=2**29
    ):
        """Compute attention for one slice, chunked along the query axis, into ``hidden_states``.

        For each chunk of query rows, this computes
        ``softmax(attn.scale * query @ key^T + attention_mask) @ value`` and writes the result
        into the matching rows of ``hidden_states`` in place. Nothing is returned.

        Args:
            attn: attention module; only ``upcast_attention`` and ``scale`` are read here.
            query: 3-D tensor ``(1, query_len, head_dim)``.
            key: 3-D tensor ``(1, key_len, head_dim)``.
            attention_mask: additive mask or ``None``. Its query dimension is either ``query_len``
                (sliced per chunk) or 1 (broadcast to every chunk).
            hidden_states: 3-D output tensor ``(1, query_len, v_dim)``, overwritten in place.
            value: 3-D tensor ``(1, key_len, v_dim)``.
            chunk: scratch buffer for scores. It is reused when it is a contiguous
                ``(1, >= rows, key_len)`` tensor with the scores' dtype and device; otherwise a
                fresh tensor is allocated for each chunk.
            max_chunk_bytes: approximate upper bound on the bytes of one block of scores
                (default ``2**29`` = 512 MiB).
        """
        # batch size = 1
        assert query.shape[0] == 1
        assert key.shape[0] == 1
        assert value.shape[0] == 1
        assert hidden_states.shape[0] == 1

        if attn.upcast_attention:
            query = query.float()
            key = key.float()

        query_len = query.shape[1]
        key_len = key.shape[1]
        if query_len == 0:
            # Nothing to compute; this also avoids a zero chunk count (division by zero) below.
            return

        # Split the query rows so that one (rows, key_len) block of scores stays under the budget.
        out_size = query_len * key_len * query.element_size()
        chunks_count = max(1, min(query_len, math.ceil((out_size - 1) / max_chunk_bytes)))
        chunk_step = max(1, query_len // chunks_count)

        key_t = key.transpose(-1, -2)

        # With beta=0, baddbmm ignores its `input`, so one broadcastable zero is enough when
        # there is no mask. Allocate it once instead of once per chunk.
        zero_bias = (
            torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype) if attention_mask is None else None
        )

        def _scores_out(rows):
            # Reuse the caller's scratch buffer only through an exactly-shaped contiguous view,
            # instead of relying on PyTorch's deprecated resizing of mismatched `out=` tensors.
            if (
                chunk is not None
                and chunk.dim() == 3
                and chunk.shape[0] == 1
                and chunk.shape[1] >= rows
                and chunk.shape[2] == key_len
                and chunk.dtype == query.dtype
                and chunk.device == query.device
                and chunk.is_contiguous()
            ):
                return chunk[:, :rows]
            return None

        for chunk_pos in range(0, query_len, chunk_step):
            rows = min(chunk_step, query_len - chunk_pos)
            stop = chunk_pos + rows

            if attention_mask is None:
                bias, beta = zero_bias, 0
            elif attention_mask.shape[1] == 1:
                # A mask with a singleton query dim broadcasts over every row; slicing it past
                # row 0 would produce an empty tensor and break baddbmm.
                bias, beta = attention_mask, 1
            else:
                bias, beta = attention_mask[:, chunk_pos:stop], 1

            scores = torch.baddbmm(
                bias,
                query[:, chunk_pos:stop],
                key_t,
                beta=beta,
                alpha=attn.scale,
                out=_scores_out(rows),
            )
            # Cast back so bmm sees matching dtypes; this is a no-op unless query/key were upcast.
            probs = scores.softmax(dim=-1).to(value.dtype)
            torch.bmm(probs, value, out=hidden_states[:, chunk_pos:stop])
|
2023-07-21 01:08:49 +00:00
|
|
|
|
|
|
|
|
|
|
|
# Replace the class on diffusers' attention_processor module, so code that looks up
# `diffusers.models.attention_processor.SlicedAttnProcessor` after this module runs gets
# ChunkedSlicedAttnProcessor. The original class remains reachable as `_SlicedAttnProcessor`.
# NOTE(review): modules that bound the class by name (`from ... import SlicedAttnProcessor`)
# before this file was imported keep the original — confirm import order.
diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
|