import json
import torch
import safetensors.torch

from dataclasses import dataclass

from diffusers import ModelMixin, ConfigMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict, Optional
from picklescan.scanner import scan_file_path

from .models import (
    BaseModelType,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SilenceWarnings,
    InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta


@dataclass
class ModelProbeInfo(object):
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    image_size: int


class ProbeBase(object):
    """forward declaration"""

    pass


class ModelProbe(object):
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
    ):
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Callable[[Path], SchedulerPredictionType] = None,
    ) -> ModelProbeInfo:
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        else:
            raise InvalidModelException("model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Path,
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> ModelProbeInfo:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the BaseModelType. It is called to distinguish
        between V2-Base and V2-768 SD models.
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
        model_info = None
        try:
            model_type = (
                cls.get_model_type_from_folder(model_path, model)
                if format_type == "diffusers"
                else cls.get_model_type_from_checkpoint(model_path, model)
            )
            format_type = "onnx" if model_type == ModelType.ONNX else format_type
            probe_class = cls.PROBES[format_type].get(model_type)
            if not probe_class:
                return None
            probe = probe_class(model_path, model, prediction_type_helper)
            base_type = probe.get_base_type()
            variant_type = probe.get_variant_type()
            prediction_type = probe.get_scheduler_prediction_type()
            format = probe.get_format()
            model_info = ModelProbeInfo(
                model_type=model_type,
                base_type=base_type,
                variant_type=variant_type,
                prediction_type=prediction_type,
                upcast_attention=(
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                ),
                format=format,
                image_size=1024
                if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
                else 768
                if (
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                )
                else 512,
            )
        except Exception:
            raise

        return model_info

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> ModelType:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        else:
            # diffusers-ti
            if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
                return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin) -> ModelType:
        """
        Get the model type of a hugging-face style folder.
        """
        class_name = None
        if model:
            class_name = model.__class__.__name__
        else:
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "learned_embeds.bin").exists():
                return ModelType.TextualInversion

            if (folder_path / "pytorch_lora_weights.bin").exists():
                return ModelType.Lora

            i = folder_path / "model_index.json"
            c = folder_path / "config.json"
            config_path = i if i.exists() else c if c.exists() else None

            if config_path:
                with open(config_path, "r") as file:
                    conf = json.load(file)
                class_name = conf["_class_name"]

        if class_name and (type := cls.CLASS2TYPE.get(class_name)):
            return type

        # give up
        raise InvalidModelException(f"Unable to determine model type for {folder_path}")

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        with SilenceWarnings():
            if model_path.suffix.endswith((".ckpt", ".pt", ".bin")):
                cls._scan_model(model_path, model_path)
                return torch.load(model_path)
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise "The model {model_name} is potentially infected by malware. Aborting import."


# ##################################################3
# Checkpoint probing
# ##################################################3
class ProbeBase(object):
    def get_base_type(self) -> BaseModelType:
        pass

    def get_variant_type(self) -> ModelVariantType:
        pass

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        pass

    def get_format(self) -> str:
        pass


class CheckpointProbeBase(ProbeBase):
    def __init__(
        self, checkpoint_path: Path, checkpoint: dict, helper: Callable[[Path], SchedulerPredictionType] = None
    ) -> BaseModelType:
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        pass

    def get_format(self) -> str:
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )


class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        type = self.get_base_type()
        if type == BaseModelType.StableDiffusion1:
            return SchedulerPredictionType.Epsilon
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            if "global_step" in checkpoint:
                if checkpoint["global_step"] == 220000:
                    return SchedulerPredictionType.Epsilon
                elif checkpoint["global_step"] == 110000:
                    return SchedulerPredictionType.VPrediction
            if (
                self.checkpoint_path and self.helper and not self.checkpoint_path.with_suffix(".yaml").exists()
            ):  # if a .yaml config file exists, then this step not needed
                return self.helper(self.checkpoint_path)
            else:
                return None


class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return None

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[0]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        else:
            return None


class ControlNetCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")


########################################################
# classes for probing folders
#######################################################
class FolderProbeBase(ProbeBase):
    def __init__(self, folder_path: Path, model: ModelMixin = None, helper: Callable = None):  # not used
        self.model = model
        self.folder_path = folder_path

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> str:
        return "diffusers"


class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / "unet" / "config.json", "r") as file:
                unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
                scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any kind of
        # exception results in our returning the
        # "normal" variant type
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                config_file = self.folder_path / "unet" / "config.json"
                with open(config_file, "r") as file:
                    conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal


class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return (
            BaseModelType.StableDiffusionXL
            if config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]
            else BaseModelType.StableDiffusion1
        )


class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return None

    def get_base_type(self) -> BaseModelType:
        path = self.folder_path / "learned_embeds.bin"
        if not path.exists():
            return None
        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
        return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()


class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else BaseModelType.StableDiffusion2
            if dimension == 1024
            else BaseModelType.StableDiffusionXL
            if dimension == 2048
            else None
        )
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
        return base_model


class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file, None).get_base_type()


############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)