from __future__ import annotations

import math
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData

if TYPE_CHECKING:
    from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
    from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs


@dataclass
class BasicConditioningInfo:
    """SD 1/2 text conditioning information produced by Compel."""

    embeds: torch.Tensor

    def to(self, device, dtype=None):
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self


@dataclass
class ConditioningFieldData:
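    """Container for the list of conditionings produced for a single conditioning output."""
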
    conditionings: List[BasicConditioningInfo]


@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)


@dataclass
class IPAdapterConditioningInfo:
    cond_image_prompt_embeds: torch.Tensor
    """IP-Adapter image encoder conditioning embeddings.
    Shape: (num_images, num_tokens, encoding_dim).
    """
    uncond_image_prompt_embeds: torch.Tensor
    """IP-Adapter image encoding embeddings to use for unconditional generation.
    Shape: (num_images, num_tokens, encoding_dim).
    """


@dataclass
class IPAdapterData:
    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo
    mask: torch.Tensor
    target_blocks: List[str]

    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    def scale_for_step(self, step_index: int, total_steps: int) -> float:
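        """Return the IP-Adapter weight to apply at `step_index`, or 0.0 if the step falls outside the
        [begin_step_percent, end_step_percent] range.

        For example, with begin_step_percent=0.0, end_step_percent=0.5 and total_steps=20, the adapter is
        active for steps 0 through 10 and returns 0.0 for the remaining steps.
        """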
        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
        last_adapter_step = math.ceil(self.end_step_percent * total_steps)
        weight = self.weight[step_index] if isinstance(self.weight, list) else self.weight
        if first_adapter_step <= step_index <= last_adapter_step:
            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
            return weight
        # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
        return 0.0


@dataclass
class Range:
    start: int
    end: int


class TextConditioningRegions:
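    """Region masks and the embedding ranges that they apply to, used for regional prompting."""
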
    def __init__(
        self,
        masks: torch.Tensor,
        ranges: list[Range],
    ):
        # Binary masks indicating the regions of the image that each prompt should be applied to.
        # Shape: (1, num_prompts, height, width)
        # Dtype: torch.bool
        self.masks = masks

        # A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
        # ranges[i] contains the embedding range for the i'th prompt / mask.
        self.ranges = ranges

        assert self.masks.shape[1] == len(self.ranges)


class ConditioningMode(Enum):
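    """Selects which text conditionings are used: the positive prompt, the negative (unconditional) prompt,
    or both concatenated into a single batch for classifier-free guidance."""
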
    Both = "both"
    Negative = "negative"
    Positive = "positive"


class TextConditioningData:
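    """Positive and negative text conditioning, optional regional prompting data, and the classifier-free
    guidance parameters for a denoising run."""
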
    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model
        # to generate images that are closely linked to the text `prompt`, usually at the expense of lower image
        # quality.
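        # In the standard CFG formulation this corresponds to:
        #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)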
        self.guidance_scale = guidance_scale
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), a guidance_rescale_multiplier of 0.7 is suggested.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
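        """Return True if this conditioning was produced for an SDXL model."""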
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
        """Fills unet arguments with data from provided conditionings.

        Args:
            unet_kwargs (UNetKwargs): Object which stores UNet model arguments.
            conditioning_mode (ConditioningMode): Describes which conditionings should be used.
        """
        _, _, h, w = unet_kwargs.sample.shape
        device = unet_kwargs.sample.device
        dtype = unet_kwargs.sample.dtype

        # TODO: combine regions with conditionings
        if conditioning_mode == ConditioningMode.Both:
            conditionings = [self.uncond_text, self.cond_text]
            c_regions = [self.uncond_regions, self.cond_regions]
        elif conditioning_mode == ConditioningMode.Positive:
            conditionings = [self.cond_text]
            c_regions = [self.cond_regions]
        elif conditioning_mode == ConditioningMode.Negative:
            conditionings = [self.uncond_text]
            c_regions = [self.uncond_regions]
        else:
            raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")

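        # Concatenate the selected conditionings into a single batch, padding to a common token length and
        # building an encoder_attention_mask if the prompt lengths differ.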
        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
            [c.embeds for c in conditionings]
        )

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask

        if self.is_sdxl():
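            # In addition to the text encoder hidden states, SDXL UNets expect the pooled text embeddings and the
            # micro-conditioning time ids via `added_cond_kwargs`.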
            added_cond_kwargs = dict(  # noqa: C408
                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
            )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs

        if any(r is not None for r in c_regions):
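            # At least one conditioning has region masks, so build RegionalPromptData and pass it to the attention
            # layers via cross_attention_kwargs.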
            tmp_regions = []
            for c, r in zip(conditionings, c_regions, strict=True):
                if r is None:
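                    # No region mask was provided for this conditioning, so apply it to the whole image with an
                    # all-ones mask spanning the full embedding range.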
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                tmp_regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
            )

    @staticmethod
    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

    @classmethod
    def _pad_conditioning(
        cls,
        cond: torch.Tensor,
        target_len: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pad provided conditioning tensor to target_len by zeros and returns mask of unpadded bytes.

        Args:
            cond (torch.Tensor): Conditioning tensor which to pads by zeros.
            target_len (int): To which length(tokens count) pad tensor.
        """
        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)

        if cond.shape[1] < target_len:
            conditioning_attention_mask = cls._pad_zeros(
                conditioning_attention_mask,
                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
                dim=1,
            )

            cond = cls._pad_zeros(
                cond,
                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                dim=1,
            )

        return cond, conditioning_attention_mask

    @classmethod
    def _concat_conditionings_for_batch(
        cls,
        conditionings: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Concatenate provided conditioning tensors to one batched tensor.
        If tensors have different sizes then pad them by zeros and creates
        encoder_attention_mask to exclude padding from attention.

        Args:
            conditionings (List[torch.Tensor]): List of conditioning tensors to concatenate.
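
        For example, inputs of shape (1, 77, 768) and (1, 154, 768) are padded and concatenated into a
        (2, 154, 768) tensor, with an encoder_attention_mask of shape (2, 154) that zeroes out the padded
        positions.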
        """
        encoder_attention_mask = None
        max_len = max(c.shape[1] for c in conditionings)
        if any(c.shape[1] != max_len for c in conditionings):
            encoder_attention_masks = [None] * len(conditionings)
            for i in range(len(conditionings)):
                conditionings[i], encoder_attention_masks[i] = cls._pad_conditioning(conditionings[i], max_len)
            encoder_attention_mask = torch.cat(encoder_attention_masks)

        return torch.cat(conditionings), encoder_attention_mask