import torch
import inspect
from tqdm import tqdm
from typing import List, Literal, Optional, Union

from pydantic import Field, validator

from ...backend.model_management import ModelType, SubModelType
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)

from .model import UNetField, ClipField, VaeField, MainModelField, ModelInfo
from .compel import ConditioningField
from .latent import LatentsField, SAMPLER_NAME_VALUES, LatentsOutput, get_scheduler, build_latents_output

class SDXLModelLoaderOutput(BaseInvocationOutput):
    """SDXL base model loader output"""

    # fmt: off
    type: Literal["sdxl_model_loader_output"] = "sdxl_model_loader_output"

    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
    vae: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on

class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
    """SDXL refiner model loader output"""
    # fmt: off
    type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
    #fmt: on
    

class SDXLModelLoaderInvocation(BaseInvocation):
    """Loads an sdxl base model, outputting its submodels."""

    type: Literal["sdxl_model_loader"] = "sdxl_main_model_loader"

    model: MainModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Model Loader",
                "tags": ["model", "loader", "sdxl"],
                "type_hints": {"model": "model"},
            },
        }

    @classmethod
    def _output_class(cls):
        return SDXLModelLoaderOutput

    def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.Main

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        return self._output_class(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
                skipped_layers=0,
            ),
            clip2=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Tokenizer2,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.TextEncoder2,
                ),
                loras=[],
                skipped_layers=0,
            ),
            vae=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=model_type,
                    submodel=SubModelType.Vae,
                ),
            ),
        )

class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
    """Loads an sdxl refiner model, outputting its submodels."""
    type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "SDXL Refiner Model Loader",
                "tags": ["model", "loader", "sdxl_refiner"],
                "type_hints": {"model": "model"},
            },
        }

    @classmethod
    def _output_class(cls):
        return SDXLRefinerModelLoaderOutput

# Text to image
class SDXLTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l_sdxl"] = "t2l_sdxl"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps:       int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    denoising_end: float = Field(default=1.0, gt=0, le=1, description="")
    #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless:   bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                  "model": "model",
                  # "cfg_scale": "float",
                  "cfg_scale": "number"
                }
            },
        }

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.noise.latents_name)

        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        prompt_embeds = positive_cond_data.conditionings[0].embeds
        pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
        add_time_ids = positive_cond_data.conditionings[0].add_time_ids

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
        negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
        add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        num_inference_steps = self.steps
        scheduler.set_timesteps(num_inference_steps)
        timesteps = scheduler.timesteps

        latents = latents * scheduler.init_noise_sigma


        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict()
        )
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with unet_info as unet:

            extra_step_kwargs = dict()
            if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    eta=0.0,
                )
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    generator=torch.Generator(device=unet.device).manual_seed(0),
                )

            num_warmup_steps = len(timesteps) - self.steps * scheduler.order

            # apply denoising_end
            skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
            num_inference_steps = num_inference_steps - skipped_final_steps
            timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]

            if not context.services.configuration.sequential_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=self.steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                        noise_pred = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                            #del noise_pred_uncond
                            #del noise_pred_text

                        #if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            #if callback is not None and i % callback_steps == 0:
                            #    callback(i, t, latents)
            else:
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=self.steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latents, t)

                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        # predict the noise residual

                        added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}
                        noise_pred_uncond = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=negative_prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
                        noise_pred_text = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                        #del noise_pred_text
                        #del noise_pred_uncond
                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        #if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        #del noise_pred
                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            #if callback is not None and i % callback_steps == 0:
                            #    callback(i, t, latents)



        #################

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)

class SDXLLatentsToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["l2l_sdxl"] = "l2l_sdxl"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps:       int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    latents: Optional[LatentsField] = Field(description="Initial latents")

    denoising_start: float = Field(default=0.0, ge=0, lt=1, description="")
    denoising_end: float = Field(default=1.0, gt=0, le=1, description="")

    #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless:   bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                  "model": "model",
                  # "cfg_scale": "float",
                  "cfg_scale": "number"
                }
            },
        }

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
        prompt_embeds = positive_cond_data.conditionings[0].embeds
        pooled_prompt_embeds = positive_cond_data.conditionings[0].pooled_embeds
        add_time_ids = positive_cond_data.conditionings[0].add_time_ids

        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
        negative_prompt_embeds = negative_cond_data.conditionings[0].embeds
        negative_pooled_prompt_embeds = negative_cond_data.conditionings[0].pooled_embeds
        add_neg_time_ids = negative_cond_data.conditionings[0].add_time_ids

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        # apply denoising_start
        num_inference_steps = self.steps
        scheduler.set_timesteps(num_inference_steps)

        t_start = int(round(self.denoising_start * num_inference_steps))
        timesteps = scheduler.timesteps[t_start * scheduler.order:]
        num_inference_steps = num_inference_steps - t_start

        # apply noise(if provided)
        if self.noise is not None:
            noise = context.services.latents.get(self.noise.latents_name)
            latents = scheduler.add_noise(latents, noise, timesteps[:1])
            del noise

        unet_info = context.services.model_manager.get_model(
            **self.unet.unet.dict()
        )
        do_classifier_free_guidance = True
        cross_attention_kwargs = None
        with unet_info as unet:

            # apply scheduler extra args
            extra_step_kwargs = dict()
            if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    eta=0.0,
                )
            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
                extra_step_kwargs.update(
                    generator=torch.Generator(device=unet.device).manual_seed(0),
                )

            num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)

            # apply denoising_end
            skipped_final_steps = int(round((1 - self.denoising_end) * self.steps))
            num_inference_steps = num_inference_steps - skipped_final_steps
            timesteps = timesteps[: num_warmup_steps + scheduler.order * num_inference_steps]

            if not context.services.configuration.sequential_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
                add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_text_embeds = add_text_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                        # predict the noise residual
                        added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
                        noise_pred = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                            noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)
                            #del noise_pred_uncond
                            #del noise_pred_text

                        #if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            #if callback is not None and i % callback_steps == 0:
                            #    callback(i, t, latents)
            else:
                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                negative_prompt_embeds = negative_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_neg_time_ids = add_neg_time_ids.to(device=unet.device, dtype=unet.dtype)
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                prompt_embeds = prompt_embeds.to(device=unet.device, dtype=unet.dtype)
                add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
                latents = latents.to(device=unet.device, dtype=unet.dtype)

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        # expand the latents if we are doing classifier free guidance
                        #latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

                        latent_model_input = scheduler.scale_model_input(latents, t)

                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        # predict the noise residual

                        added_cond_kwargs = {"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_time_ids}
                        noise_pred_uncond = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=negative_prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
                        noise_pred_text = unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=prompt_embeds,
                            cross_attention_kwargs=cross_attention_kwargs,
                            added_cond_kwargs=added_cond_kwargs,
                            return_dict=False,
                        )[0]

                        # perform guidance
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                        #del noise_pred_text
                        #del noise_pred_uncond
                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        #if do_classifier_free_guidance and guidance_rescale > 0.0:
                        #    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                        #    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                        # compute the previous noisy sample x_t -> x_t-1
                        latents = scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                        #del noise_pred
                        #import gc
                        #gc.collect()
                        #torch.cuda.empty_cache()

                        # call the callback, if provided
                        if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                            progress_bar.update()
                            #if callback is not None and i % callback_steps == 0:
                            #    callback(i, t, latents)



        #################

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)