import os
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from copy import deepcopy
from einops import rearrange
from glob import glob
from natsort import natsorted

from ldm.modules.diffusionmodules.openaimodel import (
    EncoderUNetModel,
    UNetModel,
)
from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config

# Classifier architecture instantiated for each supported label key; the
# classifier looks up its `label_key` here and rejects keys that are missing.
__models__ = {'class_label': EncoderUNetModel, 'segmentation': UNetModel}


def disabled_train(self, mode=True):
    """No-op replacement for a model's ``train`` method.

    Assign this over ``model.train`` so that later attempts to switch the
    train/eval mode have no effect. ``mode`` is accepted for signature
    compatibility and ignored; the first argument is returned unchanged.
    """
    return self


class NoisyLatentImageClassifier(pl.LightningModule):
    def __init__(
        self,
        diffusion_path,
        num_classes,
        ckpt_path=None,
        pool='attention',
        label_key=None,
        diffusion_ckpt_path=None,
        scheduler_config=None,
        weight_decay=1.0e-2,
        log_steps=10,
        monitor='val/loss',
        *args,
        **kwargs,
    ):
        """Load the frozen diffusion model and build the classifier on top.

        Args:
            diffusion_path: Directory whose ``configs`` sub-folder contains
                the ``*-project.yaml`` files of the diffusion run; the last
                one in natural sort order is loaded.
            num_classes: Number of classes; used as the classifier's
                ``out_channels``.
            ckpt_path: Optional classifier checkpoint handed to
                ``init_from_ckpt``.
            pool: Pooling setting written into the classifier config when
                the label key is ``'class_label'``.
            label_key: Batch key of the targets. Only used when the
                diffusion model has no ``cond_stage_key`` attribute.
            diffusion_ckpt_path: Written to ``params.ckpt_path`` of the
                diffusion config before the diffusion model is instantiated.
            scheduler_config: Optional scheduler config; when given,
                ``configure_optimizers`` wraps its ``schedule`` in a
                ``LambdaLR``.
            weight_decay: Weight decay passed to ``AdamW``.
            log_steps: Number of timesteps visited by ``log_images``.
            monitor: Stored as ``self.monitor``.
            *args, **kwargs: Forwarded to the parent constructor.

        Raises:
            FileNotFoundError: If no ``*-project.yaml`` exists under
                ``<diffusion_path>/configs``.
            AssertionError: If no label key can be determined.
            NotImplementedError: If the label key has no entry in
                ``__models__``.
        """
        super().__init__(*args, **kwargs)
        self.num_classes = num_classes
        # get latest config of diffusion model
        config_pattern = os.path.join(
            diffusion_path, 'configs', '*-project.yaml'
        )
        config_candidates = natsorted(glob(config_pattern))
        if not config_candidates:
            # Indexing an empty list would only give an opaque IndexError.
            raise FileNotFoundError(
                f'No diffusion project config found matching "{config_pattern}"'
            )
        diffusion_config = config_candidates[-1]
        self.diffusion_config = OmegaConf.load(diffusion_config).model
        self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
        self.load_diffusion()

        self.monitor = monitor
        # Number of 2x downsamplings applied to segmentation targets in
        # get_conditioning.
        self.numd = (
            self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
        )
        self.log_time_interval = (
            self.diffusion_model.num_timesteps // log_steps
        )
        self.log_steps = log_steps

        # The diffusion model's conditioning key takes precedence over the
        # explicitly passed label_key.
        self.label_key = (
            label_key
            if not hasattr(self.diffusion_model, 'cond_stage_key')
            else self.diffusion_model.cond_stage_key
        )

        assert (
            self.label_key is not None
        ), 'label_key neither in diffusion model nor in model.params'

        if self.label_key not in __models__:
            raise NotImplementedError(
                f'Unsupported label_key "{self.label_key}"; '
                f'expected one of {sorted(__models__)}'
            )

        self.load_classifier(ckpt_path, pool)

        self.scheduler_config = scheduler_config
        self.use_scheduler = self.scheduler_config is not None
        self.weight_decay = weight_decay

    def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
        sd = torch.load(path, map_location='cpu')
        if 'state_dict' in list(sd.keys()):
            sd = sd['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        missing, unexpected = (
            self.load_state_dict(sd, strict=False)
            if not only_model
            else self.model.load_state_dict(sd, strict=False)
        )
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
        if len(unexpected) > 0:
            print(f'Unexpected Keys: {unexpected}')

    def load_diffusion(self):
        """Instantiate the diffusion model from its config and freeze it.

        The model is put into eval mode, its ``train`` attribute is replaced
        by ``disabled_train`` and every parameter has ``requires_grad``
        switched off.
        """
        frozen = instantiate_from_config(self.diffusion_config).eval()
        self.diffusion_model = frozen
        # NOTE(review): the override is a plain function stored on the
        # instance, so it is not bound; a call such as `train(mode)` passes
        # `mode` in the `self` slot and gets it back.
        self.diffusion_model.train = disabled_train
        for weight in self.diffusion_model.parameters():
            weight.requires_grad = False

    def load_classifier(self, ckpt_path, pool):
        """Create the classifier network and optionally restore a checkpoint.

        The classifier config is a deep copy of the diffusion UNet's
        parameters with ``in_channels`` set to the UNet's ``out_channels``
        and ``out_channels`` set to ``self.num_classes``. ``pool`` is only
        added for the ``'class_label'`` key.

        Args:
            ckpt_path: Checkpoint passed to ``init_from_ckpt``; skipped when
                ``None``.
            pool: Pooling setting for the ``'class_label'`` classifier.
        """
        unet_params = self.diffusion_config.params.unet_config.params
        classifier_config = deepcopy(unet_params)
        classifier_config.in_channels = unet_params.out_channels
        classifier_config.out_channels = self.num_classes
        if self.label_key == 'class_label':
            classifier_config.pool = pool

        classifier_cls = __models__[self.label_key]
        self.model = classifier_cls(**classifier_config)
        if ckpt_path is None:
            return

        banner = '#####################################################################'
        print(banner)
        print(f'load from ckpt "{ckpt_path}"')
        print(banner)
        self.init_from_ckpt(ckpt_path)

    @torch.no_grad()
    def get_x_noisy(self, x, t, noise=None):
        """Return the diffusion model's ``q_sample`` of ``x`` at timestep ``t``.

        Args:
            x: Input passed to ``q_sample`` as ``x_start``.
            t: Timestep(s) passed to ``q_sample``.
            noise: Optional noise; when ``None`` it is drawn with
                ``torch.randn_like(x)``.
        """
        noise = default(noise, lambda: torch.randn_like(x))
        diffusion = self.diffusion_model
        if diffusion.use_continuous_noise:
            # todo: make sure t+1 is correct here
            noise_level = diffusion.sample_continuous_noise_level(
                x.shape[0], t + 1
            )
        else:
            noise_level = None

        return diffusion.q_sample(
            x_start=x,
            t=t,
            noise=noise,
            continuous_sqrt_alpha_cumprod=noise_level,
        )

    def forward(self, x_noisy, t, *args, **kwargs):
        return self.model(x_noisy, t)

    @torch.no_grad()
    def get_input(self, batch, k):
        """Fetch ``batch[k]`` and return it channel-first as contiguous float.

        A 3-dimensional entry gets a trailing axis appended before the
        ``'b h w c -> b c h w'`` rearrangement.
        """
        image = batch[k]
        if image.ndim == 3:
            image = image.unsqueeze(-1)
        image = rearrange(image, 'b h w c -> b c h w')
        return image.to(memory_format=torch.contiguous_format).float()

    @torch.no_grad()
    def get_conditioning(self, batch, k=None):
        if k is None:
            k = self.label_key
        assert k is not None, 'Needs to provide label key'

        targets = batch[k].to(self.device)

        if self.label_key == 'segmentation':
            targets = rearrange(targets, 'b h w c -> b c h w')
            for down in range(self.numd):
                h, w = targets.shape[-2:]
                targets = F.interpolate(
                    targets, size=(h // 2, w // 2), mode='nearest'
                )

            # targets = rearrange(targets,'b c h w -> b h w c')

        return targets

    def compute_top_k(self, logits, labels, k, reduction='mean'):
        _, top_ks = torch.topk(logits, k, dim=1)
        if reduction == 'mean':
            return (
                (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
            )
        elif reduction == 'none':
            return (top_ks == labels[:, None]).float().sum(dim=-1)

    def on_train_epoch_start(self):
        # save some memory
        self.diffusion_model.model.to('cpu')

    @torch.no_grad()
    def write_logs(self, loss, logits, targets):
        log_prefix = 'train' if self.training else 'val'
        log = {}
        log[f'{log_prefix}/loss'] = loss.mean()
        log[f'{log_prefix}/acc@1'] = self.compute_top_k(
            logits, targets, k=1, reduction='mean'
        )
        log[f'{log_prefix}/acc@5'] = self.compute_top_k(
            logits, targets, k=5, reduction='mean'
        )

        self.log_dict(
            log,
            prog_bar=False,
            logger=True,
            on_step=self.training,
            on_epoch=True,
        )
        self.log(
            'loss', log[f'{log_prefix}/loss'], prog_bar=True, logger=False
        )
        self.log(
            'global_step',
            self.global_step,
            logger=False,
            on_epoch=False,
            prog_bar=True,
        )
        lr = self.optimizers().param_groups[0]['lr']
        self.log(
            'lr_abs',
            lr,
            on_step=True,
            logger=True,
            on_epoch=False,
            prog_bar=True,
        )

    def shared_step(self, batch, t=None):
        x, *_ = self.diffusion_model.get_input(
            batch, k=self.diffusion_model.first_stage_key
        )
        targets = self.get_conditioning(batch)
        if targets.dim() == 4:
            targets = targets.argmax(dim=1)
        if t is None:
            t = torch.randint(
                0,
                self.diffusion_model.num_timesteps,
                (x.shape[0],),
                device=self.device,
            ).long()
        else:
            t = torch.full(
                size=(x.shape[0],), fill_value=t, device=self.device
            ).long()
        x_noisy = self.get_x_noisy(x, t)
        logits = self(x_noisy, t)

        loss = F.cross_entropy(logits, targets, reduction='none')

        self.write_logs(loss.detach(), logits.detach(), targets.detach())

        loss = loss.mean()
        return loss, logits, x_noisy, targets

    def training_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)
        return loss

    def reset_noise_accs(self):
        self.noisy_acc = {
            t: {'acc@1': [], 'acc@5': []}
            for t in range(
                0,
                self.diffusion_model.num_timesteps,
                self.diffusion_model.log_every_t,
            )
        }

    def on_validation_start(self):
        self.reset_noise_accs()

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        loss, *_ = self.shared_step(batch)

        for t in self.noisy_acc:
            _, logits, _, targets = self.shared_step(batch, t)
            self.noisy_acc[t]['acc@1'].append(
                self.compute_top_k(logits, targets, k=1, reduction='mean')
            )
            self.noisy_acc[t]['acc@5'].append(
                self.compute_top_k(logits, targets, k=5, reduction='mean')
            )

        return loss

    def configure_optimizers(self):
        optimizer = AdamW(
            self.model.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )

        if self.use_scheduler:
            scheduler = instantiate_from_config(self.scheduler_config)

            print('Setting up LambdaLR scheduler...')
            scheduler = [
                {
                    'scheduler': LambdaLR(
                        optimizer, lr_lambda=scheduler.schedule
                    ),
                    'interval': 'step',
                    'frequency': 1,
                }
            ]
            return [optimizer], scheduler

        return optimizer

    @torch.no_grad()
    def log_images(self, batch, N=8, *args, **kwargs):
        """Collect up to ``N`` images per entry for visual logging.

        Always logs the inputs. For the ``'class_label'`` key the labels are
        rendered from ``batch['human_label']``. When the conditioning passes
        ``ismap``, it is converted with the diffusion model's ``to_rgb`` and,
        for ``self.log_steps`` evenly spaced timesteps, the noised inputs and
        the colourised arg-max predictions are added as well.
        """
        diffusion = self.diffusion_model
        x = self.get_input(batch, diffusion.first_stage_key)
        log = {'inputs': x}

        y = self.get_conditioning(batch)

        if self.label_key == 'class_label':
            y = log_txt_as_img((x.shape[2], x.shape[3]), batch['human_label'])
            log['labels'] = y

        if ismap(y):
            log['labels'] = diffusion.to_rgb(y)

            for idx in range(self.log_steps):
                timestep = idx * self.log_time_interval

                _, logits, x_noisy, _ = self.shared_step(batch, t=timestep)
                log[f'inputs@t{timestep}'] = x_noisy

                # One-hot encode the predicted class so it can be coloured
                # the same way as the label map.
                one_hot = F.one_hot(
                    logits.argmax(dim=1), num_classes=self.num_classes
                )
                one_hot = rearrange(one_hot, 'b h w c -> b c h w')
                log[f'pred@t{timestep}'] = diffusion.to_rgb(one_hot)

        return {key: value[:N] for key, value in log.items()}