import math

import diffusers
import torch

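# MPS workaround: use zero-initialised tensors in place of torch.empty, presumably to avoid
# garbage values from uninitialised buffers leaking into fp16 computations.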
if torch.backends.mps.is_available():
    torch.empty = torch.zeros


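# MPS workaround: fp16 layer_norm is unreliable on some torch versions, so run the op in fp32
# and cast the result back to half precision.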
_torch_layer_norm = torch.nn.functional.layer_norm


def new_layer_norm(input, normalized_shape, weight=None, bias=None, eps=1e-05):
    if input.device.type == "mps" and input.dtype == torch.float16:
        input = input.float()
        if weight is not None:
            weight = weight.float()
        if bias is not None:
            bias = bias.float()
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps).half()
    else:
        return _torch_layer_norm(input, normalized_shape, weight, bias, eps)


torch.nn.functional.layer_norm = new_layer_norm


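# MPS workaround: return permuted fp16 tensors as contiguous copies, since some downstream ops
# misbehave on non-contiguous fp16 views on MPS.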
_torch_tensor_permute = torch.Tensor.permute


def new_torch_tensor_permute(input, *dims):
    result = _torch_tensor_permute(input, *dims)
    if input.device.type == "mps" and input.dtype == torch.float16:
        result = result.contiguous()
    return result


torch.Tensor.permute = new_torch_tensor_permute


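# MPS workaround: torch.lerp does not handle fp16 reliably, so compute in fp32 and cast back.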
_torch_lerp = torch.lerp


def new_torch_lerp(input, end, weight, *, out=None):
    if input.device.type == "mps" and input.dtype == torch.float16:
        input = input.float()
        end = end.float()
        if isinstance(weight, torch.Tensor):
            weight = weight.float()
        if out is not None:
            out_fp32 = torch.zeros_like(out, dtype=torch.float32)
        else:
            out_fp32 = None
        result = _torch_lerp(input, end, weight, out=out_fp32)
        if out is not None:
            out.copy_(out_fp32.half())
            del out_fp32
            return out
        return result.half()

    else:
        return _torch_lerp(input, end, weight, out=out)


torch.lerp = new_torch_lerp


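# MPS workaround: the same fp32 upcast for torch.nn.functional.interpolate on fp16 inputs.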
_torch_interpolate = torch.nn.functional.interpolate


def new_torch_interpolate(
    input,
    size=None,
    scale_factor=None,
    mode="nearest",
    align_corners=None,
    recompute_scale_factor=None,
    antialias=False,
):
    if input.device.type == "mps" and input.dtype == torch.float16:
        return _torch_interpolate(
            input.float(), size, scale_factor, mode, align_corners, recompute_scale_factor, antialias
        ).half()
    else:
        return _torch_interpolate(input, size, scale_factor, mode, align_corners, recompute_scale_factor, antialias)


torch.nn.functional.interpolate = new_torch_interpolate

# TODO: refactor it
_SlicedAttnProcessor = diffusers.models.attention_processor.SlicedAttnProcessor


class ChunkedSlicedAttnProcessor:
    r"""
    Processor for implementing sliced attention.

    Args:
        slice_size (`int`, *optional*):
            The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and
            `attention_head_dim` must be a multiple of the `slice_size`.
    """

    def __init__(self, slice_size):
        assert isinstance(slice_size, int)
        slice_size = 1  # TODO: maybe implement chunking in batches too when enough memory
        self.slice_size = slice_size
        self._sliced_attn_processor = _SlicedAttnProcessor(slice_size)

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if self.slice_size != 1 or attn.upcast_attention:
            return self._sliced_attn_processor(attn, hidden_states, encoder_hidden_states, attention_mask)

        residual = hidden_states

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

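        # Preallocate the full attention output: one row of head-dim features per query token
        # for every head-batch entry.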
        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

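        # Reusable scratch buffer for one slice of raw attention scores; torch.empty is patched
        # to torch.zeros on MPS above, so this allocation is zero-filled there.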
        chunk_tmp_tensor = torch.empty(
            self.slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
        )

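        # Attend one head-batch entry at a time (slice_size is forced to 1 in __init__).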
        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            self.get_attention_scores_chunked(
                attn,
                query_slice,
                key_slice,
                attn_mask_slice,
                hidden_states[start_idx:end_idx],
                value[start_idx:end_idx],
                chunk_tmp_tensor,
            )

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def get_attention_scores_chunked(self, attn, query, key, attention_mask, hidden_states, value, chunk):
        # slice_size is forced to 1, so each call handles a single head-batch entry
        assert query.shape[0] == 1
        assert key.shape[0] == 1
        assert value.shape[0] == 1
        assert hidden_states.shape[0] == 1

        if attn.upcast_attention:
            query = query.float()
            key = key.float()

        # Size in bytes of one attention-score element (4 bytes when upcast to float32).
        out_item_size = query.element_size()
        if attn.upcast_attention:
            out_item_size = 4

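        # Split the query dimension so a single attention-score chunk stays at or below roughly
        # 2**29 bytes (512 MiB).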
        chunk_size = 2**29

        out_size = query.shape[1] * key.shape[1] * out_item_size
        chunks_count = min(query.shape[1], math.ceil((out_size - 1) / chunk_size))
        chunk_step = max(1, int(query.shape[1] / chunks_count))

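        # Transpose key once so each chunk is computed as query_chunk @ key^T via baddbmm.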
        key = key.transpose(-1, -2)

        def _get_chunk_view(tensor, start, length):
            if start + length > tensor.shape[1]:
                length = tensor.shape[1] - start
            # print(f"view: [{tensor.shape[0]},{tensor.shape[1]},{tensor.shape[2]}] - start: {start}, length: {length}")
            return tensor[:, start : start + length]

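        # For each query chunk: compute the scaled (and optionally masked) attention scores
        # with out=chunk, apply softmax, then write the weighted values straight into the
        # matching view of hidden_states.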
        for chunk_pos in range(0, query.shape[1], chunk_step):
            if attention_mask is not None:
                torch.baddbmm(
                    _get_chunk_view(attention_mask, chunk_pos, chunk_step),
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=1,
                    alpha=attn.scale,
                    out=chunk,
                )
            else:
                torch.baddbmm(
                    torch.zeros((1, 1, 1), device=query.device, dtype=query.dtype),
                    _get_chunk_view(query, chunk_pos, chunk_step),
                    key,
                    beta=0,
                    alpha=attn.scale,
                    out=chunk,
                )
            chunk = chunk.softmax(dim=-1)
            torch.bmm(chunk, value, out=_get_chunk_view(hidden_states, chunk_pos, chunk_step))


diffusers.models.attention_processor.SlicedAttnProcessor = ChunkedSlicedAttnProcessor
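# Importing this module is enough to apply all of the patches above; import it before enabling
# attention slicing on a diffusers pipeline so the chunked processor is the one instantiated.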