import io
from typing import Literal, Optional, Any

# from PIL.Image import Image
import PIL.Image
from matplotlib.ticker import MaxNLocator
from matplotlib.figure import Figure

from pydantic import BaseModel, Field
import numpy as np
import matplotlib.pyplot as plt

from easing_functions import (
    LinearInOut,
    QuadEaseInOut, QuadEaseIn, QuadEaseOut,
    CubicEaseInOut, CubicEaseIn, CubicEaseOut,
    QuarticEaseInOut, QuarticEaseIn, QuarticEaseOut,
    QuinticEaseInOut, QuinticEaseIn, QuinticEaseOut,
    SineEaseInOut, SineEaseIn, SineEaseOut,
    CircularEaseIn, CircularEaseInOut, CircularEaseOut,
    ExponentialEaseInOut, ExponentialEaseIn, ExponentialEaseOut,
    ElasticEaseIn, ElasticEaseInOut, ElasticEaseOut,
    BackEaseIn, BackEaseInOut, BackEaseOut,
    BounceEaseIn, BounceEaseInOut, BounceEaseOut)

from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    InvocationContext,
    InvocationConfig,
)
from ...backend.util.logging import InvokeAILogger
from .collections import FloatCollectionOutput


class FloatLinearRangeInvocation(BaseInvocation):
    """Creates a range"""
    # Node that emits `steps` evenly spaced floats from `start` to `stop`, both endpoints included.
    # (The docstring above doubles as the node's schema description, so it is left unchanged.)

    type: Literal["float_range"] = "float_range"

    # Inputs
    start: float = Field(default=5, description="The first value of the range")
    stop: float = Field(default=10, description="The last value of the range")
    steps: int = Field(default=30, description="number of values to interpolate over (including start and stop)")

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Linear Range (Float)",
                "tags": ["math", "float", "linear", "range"]
            },
        }

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        # np.linspace includes both endpoints, so exactly `steps` samples come back
        evenly_spaced = np.linspace(self.start, self.stop, self.steps)
        return FloatCollectionOutput(collection=list(evenly_spaced))


# Maps each user-selectable easing name to the easing_functions class that implements it.
# Dict insertion order carries through to EASING_FUNCTION_KEYS below.
EASING_FUNCTIONS_MAP = {
    "Linear": LinearInOut,
    "QuadIn": QuadEaseIn,
    "QuadOut": QuadEaseOut,
    "QuadInOut": QuadEaseInOut,
    "CubicIn": CubicEaseIn,
    "CubicOut": CubicEaseOut,
    "CubicInOut": CubicEaseInOut,
    "QuarticIn": QuarticEaseIn,
    "QuarticOut": QuarticEaseOut,
    "QuarticInOut": QuarticEaseInOut,
    "QuinticIn": QuinticEaseIn,
    "QuinticOut": QuinticEaseOut,
    "QuinticInOut": QuinticEaseInOut,
    "SineIn": SineEaseIn,
    "SineOut": SineEaseOut,
    "SineInOut": SineEaseInOut,
    "CircularIn": CircularEaseIn,
    "CircularOut": CircularEaseOut,
    "CircularInOut": CircularEaseInOut,
    "ExponentialIn": ExponentialEaseIn,
    "ExponentialOut": ExponentialEaseOut,
    "ExponentialInOut": ExponentialEaseInOut,
    "ElasticIn": ElasticEaseIn,
    "ElasticOut": ElasticEaseOut,
    "ElasticInOut": ElasticEaseInOut,
    "BackIn": BackEaseIn,
    "BackOut": BackEaseOut,
    "BackInOut": BackEaseInOut,
    "BounceIn": BounceEaseIn,
    "BounceOut": BounceEaseOut,
    "BounceInOut": BounceEaseInOut,
}

# Literal type whose allowed values are exactly the keys of EASING_FUNCTIONS_MAP; used as the
# annotation of an invocation's `easing` input. Subscripting Literal with a tuple is equivalent to
# listing the values individually, and iterating a dict yields its keys.
# Typed as Any presumably so static checkers accept this dynamically built Literal.
EASING_FUNCTION_KEYS: Any = Literal[tuple(EASING_FUNCTIONS_MAP)]


def _eased_values(easing_class: Any, start_value: float, end_value: float, count: int) -> list:
    """Sample `count` consecutive steps of an easing curve running from start_value towards end_value.

    Step 0 is the start of the curve; when count > 1 the last sample is the curve at its full
    duration (count - 1). Returns an empty list when count <= 0.
    """
    if count <= 0:
        return []
    # A one-step span would otherwise get duration 0. easing_functions divides the step index by the
    # duration (NOTE(review): confirm for the pinned version), so use 1 -- only step 0 is sampled anyway.
    easing_function = easing_class(start=start_value, end=end_value, duration=max(count - 1, 1))
    return [easing_function.ease(step_index) for step_index in range(count)]


def _mirrored_eased_values(easing_class: Any, start_value: float, end_value: float, count: int) -> list:
    """Ease over the first ceil(count / 2) steps, then append the same values in reverse.

    For an even count the whole first half is mirrored; for an odd count the last value of the first
    half (the peak) is not repeated. Either way the result has exactly max(count, 0) entries.
    """
    half = (count + 1) // 2  # ceil(count / 2) for non-negative count
    base_values = _eased_values(easing_class, start_value, end_value, half)
    mirror_source = base_values if count % 2 == 0 else base_values[:-1]
    return base_values + mirror_source[::-1]


def _show_easing_plot(param_list: list, easing_name: str) -> None:
    """Render param_list as a bar chart and open it with PIL's image viewer (debugging aid)."""
    fig = plt.figure()
    try:
        ax = fig.gca()
        ax.set_xlabel("Step")
        ax.set_ylabel("Param Value")
        ax.set_title("Per-Step Values Based On Easing: " + easing_name)
        ax.bar(range(len(param_list)), param_list)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png")
            buf.seek(0)
            with PIL.Image.open(buf) as im:
                im.show()
    finally:
        # pyplot keeps every figure it creates referenced until it is explicitly closed,
        # so close it to avoid accumulating one figure per invocation.
        plt.close(fig)


# NOTE: for now this could probably just output CollectionOutput (which is list[Any]).
# Produces one value per denoising step: a constant "pre" segment, an eased segment running from
# start_value to end_value (optionally mirrored back towards start_value), and a constant "post" segment.
# (The docstring below doubles as the node's schema description, so it is left unchanged.)
class StepParamEasingInvocation(BaseInvocation):
    """Experimental per-step parameter easing for denoising steps"""

    type: Literal["step_param_easing"] = "step_param_easing"

    # Inputs
    # fmt: off
    easing: EASING_FUNCTION_KEYS = Field(default="Linear", description="The easing function to use")
    num_steps: int = Field(default=20, description="number of denoising steps")
    start_value: float = Field(default=0.0, description="easing starting value")
    end_value: float = Field(default=1.0, description="easing ending value")
    start_step_percent: float = Field(default=0.0, description="fraction of steps at which to start easing")
    end_step_percent: float = Field(default=1.0, description="fraction of steps after which to end easing")
    # if None, then start_value is used for the steps before easing starts
    pre_start_value: Optional[float] = Field(default=None, description="value before easing start")
    # if None, then end_value is used for the steps after easing ends
    post_end_value: Optional[float] = Field(default=None, description="value after easing end")
    mirror: bool = Field(default=False, description="include mirror of easing function")
    # FIXME: add alt_mirror option (dual easing: start->end over the first half of the span, then
    #        end->start over the second half), or drop the idea entirely
    show_easing_plot: bool = Field(default=False, description="show easing plot")
    # fmt: on

    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Param Easing By Step",
                "tags": ["param", "step", "easing"]
            },
        }

    def invoke(self, context: InvocationContext) -> FloatCollectionOutput:
        log_diagnostics = False  # set True to trace the step bookkeeping at debug level

        # A negative step count yields an empty list, just like 0.
        total_steps = max(self.num_steps, 0)
        # First eased step: nearest step to num_steps * start_step_percent, kept within [0, num_steps].
        start_step = min(max(int(np.round(total_steps * self.start_step_percent)), 0), total_steps)
        # Last eased step: nearest step to (num_steps - 1) * end_step_percent.
        end_step = int(np.round((total_steps - 1) * self.end_step_percent))
        # Clamp so an inverted or out-of-range span cannot make the pre/eased/post segments
        # add up to anything other than num_steps.
        num_easing_steps = min(max(end_step - start_step + 1, 0), total_steps - start_step)
        num_presteps = start_step
        num_poststeps = total_steps - num_presteps - num_easing_steps

        # Fall back to the easing endpoints when no explicit pre/post value was given, so the
        # collection never contains None.
        pre_value = self.start_value if self.pre_start_value is None else self.pre_start_value
        post_value = self.end_value if self.post_end_value is None else self.post_end_value

        easing_class = EASING_FUNCTIONS_MAP[self.easing]
        if self.mirror:
            easing_list = _mirrored_eased_values(easing_class, self.start_value, self.end_value, num_easing_steps)
        else:
            easing_list = _eased_values(easing_class, self.start_value, self.end_value, num_easing_steps)

        # param_list[i] is the param value for denoising step i
        param_list = num_presteps * [pre_value] + easing_list + num_poststeps * [post_value]

        if log_diagnostics:
            context.services.logger.debug(
                f"start_step: {start_step}, end_step: {end_step}, num_presteps: {num_presteps}, "
                f"num_easing_steps: {num_easing_steps}, num_poststeps: {num_poststeps}, "
                f"easing class: {easing_class}, param_list: {param_list}"
            )

        if self.show_easing_plot:
            _show_easing_plot(param_list, self.easing)

        return FloatCollectionOutput(
            collection=param_list
        )