import math
from typing import Optional, Union

import torch
from transformers import CLIPTokenizer, CLIPTextModel

from ldm.modules.textual_inversion_manager import TextualInversionManager
from ldm.invoke.devices import torch_dtype

class WeightedPromptFragmentsToEmbeddingsConverter():

    def __init__(self,
                tokenizer: CLIPTokenizer, # converts strings to lists of int token ids
                text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings
                textual_inversion_manager: TextualInversionManager = None
                ):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.textual_inversion_manager = textual_inversion_manager

    @property
    def max_length(self):
        return self.tokenizer.model_max_length

    def get_embeddings_for_weighted_prompt_fragments(self,
                                                     text: list[str],
                                                     fragment_weights: list[float],
                                                     should_return_tokens: bool = False,
                                                     device='cpu'
                                                     ) -> torch.Tensor:
        '''

        :param text: A list of fragments of text to which different weights are to be applied.
        :param fragment_weights: A batch of lists of weights, one for each entry in `fragments`.
        :return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1
                    and 1280 for SD2
        '''
        if len(text) != len(fragment_weights):
            raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})")

        batch_z = None
        batch_tokens = None
        for fragments, weights in zip(text, fragment_weights):

            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
            # applying a multiplier to the CFG scale on a per-token basis).
            # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
            # "red" is to tell SD that it should almost completely *ignore* redness).
            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
            # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.

            # handle weights >=1
            tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights)

            # this is our starting point
            embeddings = base_embedding.unsqueeze(0)
            per_embedding_weights = [1.0]

            # now handle weights <1
            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
            # removed.
            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
            for index, fragment_weight in enumerate(weights):
                if fragment_weight < 1:
                    fragments_without_this = fragments[:index] + fragments[index+1:]
                    weights_without_this = weights[:index] + weights[index+1:]
                    tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device)
                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights)

                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
                    # therefore:
                    # fragment_weight = 1: we are at base_z => lerp weight 0
                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
                    # so let's use tan(), because:
                    # tan is 0.0 at 0,
                    #        1.0 at PI/4, and
                    #        inf at PI/2
                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
                    epsilon = 1e-9
                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
                    # todo handle negative weight?

                    per_embedding_weights.append(embedding_lerp_weight)

            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)

            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")

            # append to batch
            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
            batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)

        # should have shape (B, 77, 768)
        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")

        if should_return_tokens:
            return batch_z, batch_tokens
        else:
            return batch_z

    def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
        """
        Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
        `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
        `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
        (typically 75 tokens + eos/bos markers).

        :param fragments: The strings to convert.
        :param include_start_and_end_markers:
        :return:
        """
        # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
        token_ids_list = self.tokenizer(
            fragments,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding='do_not_pad',
            return_tensors=None,  # just give me lists of ints
        )['input_ids']

        result = []
        for token_ids in token_ids_list:
            # trim eos/bos
            token_ids = token_ids[1:-1]
            # pad for textual inversions with vector length >1
            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
            # restrict length to max_length-2 (leaving room for bos/eos)
            token_ids = token_ids[0:self.max_length - 2]
            # add back eos/bos if requested
            if include_start_and_end_markers:
                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]

            result.append(token_ids)

        return result


    @classmethod
    def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
        if normalize:
            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
        return torch.sum(embeddings * reshaped_weights, dim=1)
        # lerped embeddings has shape (77, 768)


    def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
        '''
        Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences
        together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded
        or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in
        weights to match each token.

        :param fragments: Text fragments to tokenize and concatenate. May be empty.
        :param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5
                        values >1.6 tend to produce garbage output. Must have same length as `fragment`.
        :return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`.
        '''
        if len(fragments) != len(weights):
            raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})")

        # empty is meaningful
        if len(fragments) == 0:
            fragments = ['']
            weights = [1.0]
        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
        all_token_ids = []
        per_token_weights = []
        #print("all fragments:", fragments, weights)
        for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
            # append
            all_token_ids += this_fragment_token_ids
            # fill out weights tensor with one float per token
            per_token_weights += [float(weight)] * len(this_fragment_token_ids)

        # leave room for bos/eos
        max_token_count_without_bos_eos_markers = self.max_length - 2
        if len(all_token_ids) > max_token_count_without_bos_eos_markers:
            excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
            # TODO build nice description string of how the truncation was applied
            # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
            # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
            print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
            all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
            per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]

        # pad out to a self.max_length-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
        # (typically self.max_length == 77)
        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
        per_token_weights = [1.0] + per_token_weights + [1.0]
        pad_length = self.max_length - len(all_token_ids)
        all_token_ids += [self.tokenizer.eos_token_id] * pad_length
        per_token_weights += [1.0] * pad_length

        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
        return all_token_ids_tensor, per_token_weights_tensor

    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
        '''
        Build a tensor that embeds the passed-in token IDs and applyies the given per_token weights
        :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
        :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
        :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
        where `token_dim` is 768 for SD1 and 1280 for SD2.
        '''
        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
        if token_ids.shape != torch.Size([self.max_length]):
            raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")

        z = self.text_encoder.forward(input_ids=token_ids.unsqueeze(0),
                                      return_dict=False)[0]
        empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
                                    [self.tokenizer.pad_token_id] * (self.max_length-2) +
                                    [self.tokenizer.eos_token_id], dtype=torch.int, device=token_ids.device).unsqueeze(0)
        empty_z = self.text_encoder(input_ids=empty_token_ids).last_hidden_state
        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
        z_delta_from_empty = z - empty_z
        weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

        return weighted_z