import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat

from .models import (
    BaseModelType,
    InvalidModelException,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SilenceWarnings,
)
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length


@dataclass
class ModelProbeInfo(object):
    model_type: ModelType
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: Literal["diffusers", "checkpoint", "lycoris", "olive", "onnx"]
    image_size: int
    name: Optional[str] = None
    description: Optional[str] = None
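
    # Illustrative only: probing a typical SD 1.x checkpoint might yield
    #   ModelProbeInfo(model_type=ModelType.Main, base_type=BaseModelType.StableDiffusion1,
    #                  variant_type=ModelVariantType.Normal, prediction_type=SchedulerPredictionType.Epsilon,
    #                  upcast_attention=False, format="checkpoint", image_size=512)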


class ProbeBase(object):
    """Forward declaration; the full definition appears in the checkpoint-probing section below."""


class ModelProbe(object):
    PROBES = {
        "diffusers": {},
        "checkpoint": {},
        "onnx": {},
    }

    CLASS2TYPE = {
        "StableDiffusionPipeline": ModelType.Main,
        "StableDiffusionInpaintPipeline": ModelType.Main,
        "StableDiffusionXLPipeline": ModelType.Main,
        "StableDiffusionXLImg2ImgPipeline": ModelType.Main,
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "LatentConsistencyModelPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "AutoencoderTiny": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
        "T2IAdapter": ModelType.T2IAdapter,
    }

    @classmethod
    def register_probe(
        cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: type[ProbeBase]
    ):
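        """Register a probe class to handle the given format/model-type combination.

        See the registrations at the bottom of this module.
        """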
        cls.PROBES[format][model_type] = probe_class

    @classmethod
    def heuristic_probe(
        cls,
        model: Union[Dict, ModelMixin, Path],
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        if isinstance(model, Path):
            return cls.probe(model_path=model, prediction_type_helper=prediction_type_helper)
        elif isinstance(model, (dict, ModelMixin, ConfigMixin)):
            return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
        else:
            raise InvalidModelException("model parameter {model} is neither a Path, nor a model")

    @classmethod
    def probe(
        cls,
        model_path: Optional[Path],
        model: Optional[Union[Dict, ModelMixin]] = None,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ) -> Optional[ModelProbeInfo]:
        """
        Probe the model at model_path and return sufficient information about it
        to place it somewhere in the models directory hierarchy. If the model is
        already loaded into memory, you may provide it as model in order to avoid
        opening it a second time. The prediction_type_helper callable is a function that receives
        the path to the model and returns the SchedulerPredictionType.
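
        Usage sketch (the path is hypothetical):

            info = ModelProbe.probe(Path("/models/v1-5-pruned-emaonly.safetensors"))
            if info:
                print(info.base_type, info.model_type, info.format)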
        """
        if model_path:
            format_type = "diffusers" if model_path.is_dir() else "checkpoint"
        else:
            format_type = "diffusers" if isinstance(model, (ConfigMixin, ModelMixin)) else "checkpoint"
        model_type = (
            cls.get_model_type_from_folder(model_path, model)
            if format_type == "diffusers"
            else cls.get_model_type_from_checkpoint(model_path, model)
        )
        format_type = "onnx" if model_type == ModelType.ONNX else format_type
        probe_class = cls.PROBES[format_type].get(model_type)
        if not probe_class:
            return None
        probe = probe_class(model_path, model, prediction_type_helper)
        base_type = probe.get_base_type()
        variant_type = probe.get_variant_type()
        prediction_type = probe.get_scheduler_prediction_type()
        name = cls.get_model_name(model_path)
        description = f"{base_type.value} {model_type.value} model {name}"
        format = probe.get_format()
        return ModelProbeInfo(
            model_type=model_type,
            base_type=base_type,
            variant_type=variant_type,
            prediction_type=prediction_type,
            name=name,
            description=description,
            upcast_attention=(
                base_type == BaseModelType.StableDiffusion2
                and prediction_type == SchedulerPredictionType.VPrediction
            ),
            format=format,
            image_size=(
                1024
                if base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
                else 768
                if (
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                )
                else 512
            ),
        )

    @classmethod
    def get_model_name(cls, model_path: Path) -> str:
        if model_path.suffix in {".safetensors", ".bin", ".pt", ".ckpt"}:
            return model_path.stem
        else:
            return model_path.name

    @classmethod
    def get_model_type_from_checkpoint(cls, model_path: Path, checkpoint: dict) -> Optional[ModelType]:
        if model_path.suffix not in (".bin", ".pt", ".ckpt", ".safetensors", ".pth"):
            return None

        if model_path.name == "learned_embeds.bin":
            return ModelType.TextualInversion

        ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
        ckpt = ckpt.get("state_dict", ckpt)

        for key in ckpt.keys():
            if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
                return ModelType.Main
            elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
                return ModelType.Vae
            elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
                return ModelType.Lora
            elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
                return ModelType.Lora
            elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
                return ModelType.ControlNet
            elif key in {"emb_params", "string_to_param"}:
                return ModelType.TextualInversion

        # diffusers-ti: a small state dict containing only raw tensors
        if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            return ModelType.TextualInversion

        raise InvalidModelException(f"Unable to determine model type for {model_path}")

    @classmethod
    def get_model_type_from_folder(cls, folder_path: Path, model: Optional[ModelMixin]) -> ModelType:
        """
        Get the model type of a hugging-face style folder.
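
        For example, a diffusers folder whose model_index.json declares
        "_class_name": "StableDiffusionPipeline" maps to ModelType.Main via the
        CLASS2TYPE table.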
        """
        class_name = None
        error_hint = None
        if model:
            class_name = model.__class__.__name__
        else:
            for suffix in ["bin", "safetensors"]:
                if (folder_path / f"learned_embeds.{suffix}").exists():
                    return ModelType.TextualInversion
                if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
                    return ModelType.Lora
            if (folder_path / "unet/model.onnx").exists():
                return ModelType.ONNX
            if (folder_path / "image_encoder.txt").exists():
                return ModelType.IPAdapter

            i = folder_path / "model_index.json"
            c = folder_path / "config.json"
            config_path = i if i.exists() else c if c.exists() else None

            if config_path:
                with open(config_path, "r") as file:
                    conf = json.load(file)
                if "_class_name" in conf:
                    class_name = conf["_class_name"]
                elif "architectures" in conf:
                    class_name = conf["architectures"][0]
                else:
                    class_name = None
            else:
                error_hint = f"No model_index.json or config.json found in {folder_path}."

        if class_name and (model_type := cls.CLASS2TYPE.get(class_name)):
            return model_type
        elif class_name:
            error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"

        # give up
        raise InvalidModelException(
            f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
        )

    @classmethod
    def _scan_and_load_checkpoint(cls, model_path: Path) -> dict:
        with SilenceWarnings():
            if model_path.suffix in (".ckpt", ".pt", ".bin"):
                cls._scan_model(model_path, model_path)
                return torch.load(model_path, map_location="cpu")
            else:
                return safetensors.torch.load_file(model_path)

    @classmethod
    def _scan_model(cls, model_name, checkpoint):
        """
        Apply picklescanner to the indicated checkpoint and issue a warning
        and option to exit if an infected file is identified.
        """
        # scan model
        scan_result = scan_file_path(checkpoint)
        if scan_result.infected_files != 0:
            raise Exception(f"The model {model_name} is potentially infected by malware. Aborting import.")


# ##################################################
# Checkpoint probing
# ##################################################
class ProbeBase(object):
    """Base class for probes. Subclasses override the relevant getters below."""

    def get_base_type(self) -> BaseModelType:
        pass

    def get_variant_type(self) -> ModelVariantType:
        pass

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        pass

    def get_format(self) -> str:
        pass


class CheckpointProbeBase(ProbeBase):
    def __init__(
        self,
        checkpoint_path: Path,
        checkpoint: dict,
        helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
    ):
        self.checkpoint = checkpoint or ModelProbe._scan_and_load_checkpoint(checkpoint_path)
        self.checkpoint_path = checkpoint_path
        self.helper = helper

    def get_base_type(self) -> BaseModelType:
        pass

    def get_format(self) -> str:
        return "checkpoint"

    def get_variant_type(self) -> ModelVariantType:
        model_type = ModelProbe.get_model_type_from_checkpoint(self.checkpoint_path, self.checkpoint)
        if model_type != ModelType.Main:
            return ModelVariantType.Normal
        state_dict = self.checkpoint.get("state_dict") or self.checkpoint
        in_channels = state_dict["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
        if in_channels == 9:
            return ModelVariantType.Inpaint
        elif in_channels == 5:
            return ModelVariantType.Depth
        elif in_channels == 4:
            return ModelVariantType.Normal
        else:
            raise InvalidModelException(
                f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}"
            )


class PipelineCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        state_dict = self.checkpoint.get("state_dict") or checkpoint
        key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 768:
            return BaseModelType.StableDiffusion1
        if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
            return BaseModelType.StableDiffusion2
        key_name = "model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight"
        if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
            return BaseModelType.StableDiffusionXL
        elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        else:
            raise InvalidModelException("Cannot determine base type")

    def get_scheduler_prediction_type(self) -> Optional[SchedulerPredictionType]:
        """Return model prediction type."""
        # If a .yaml config file is associated with this checkpoint, the prediction
        # type is taken from it, so any probed value would be ignored.
        if self.checkpoint_path and self.checkpoint_path.with_suffix(".yaml").exists():
            return None

        base_type = self.get_base_type()
        if base_type == BaseModelType.StableDiffusion2:
            checkpoint = self.checkpoint
            state_dict = self.checkpoint.get("state_dict") or checkpoint
            key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
            if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
                if "global_step" in checkpoint:
                    if checkpoint["global_step"] == 220000:
                        return SchedulerPredictionType.Epsilon
                    elif checkpoint["global_step"] == 110000:
                        return SchedulerPredictionType.VPrediction
            if self.helper and self.checkpoint_path:
                if helper_guess := self.helper(self.checkpoint_path):
                    return helper_guess
            return SchedulerPredictionType.VPrediction  # a guess for sd2 ckpts

        elif base_type == BaseModelType.StableDiffusion1:
            if self.helper and self.checkpoint_path:
                if helper_guess := self.helper(self.checkpoint_path):
                    return helper_guess
            return SchedulerPredictionType.Epsilon  # a reasonable guess for sd1 ckpts
        else:
            return None


class VaeCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        # I can't find any standalone 2.X VAEs to test with!
        return BaseModelType.StableDiffusion1


class LoRACheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> str:
        return "lycoris"

    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        token_vector_length = lora_token_vector_length(checkpoint)

        if token_vector_length == 768:
            return BaseModelType.StableDiffusion1
        elif token_vector_length == 1024:
            return BaseModelType.StableDiffusion2
        elif token_vector_length == 1280:
            return BaseModelType.StableDiffusionXL  # recognizes format at https://civitai.com/models/224641
        elif token_vector_length == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown LoRA type: {self.checkpoint_path}")


class TextualInversionCheckpointProbe(CheckpointProbeBase):
    def get_format(self) -> Optional[str]:
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        checkpoint = self.checkpoint
        if "string_to_token" in checkpoint:
            token_dim = list(checkpoint["string_to_param"].values())[0].shape[-1]
        elif "emb_params" in checkpoint:
            token_dim = checkpoint["emb_params"].shape[-1]
        elif "clip_g" in checkpoint:
            token_dim = checkpoint["clip_g"].shape[-1]
        else:
            token_dim = list(checkpoint.values())[0].shape[-1]
        if token_dim == 768:
            return BaseModelType.StableDiffusion1
        elif token_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif token_dim == 1280:
            return BaseModelType.StableDiffusionXL
        else:
            return None


class ControlNetCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        checkpoint = self.checkpoint
        for key_name in (
            "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
            "input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
        ):
            if key_name not in checkpoint:
                continue
            if checkpoint[key_name].shape[-1] == 768:
                return BaseModelType.StableDiffusion1
            elif checkpoint[key_name].shape[-1] == 1024:
                return BaseModelType.StableDiffusion2
            elif self.checkpoint_path and self.helper:
                return self.helper(self.checkpoint_path)
        raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class T2IAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
########################################################
class FolderProbeBase(ProbeBase):
    def __init__(self, folder_path: Path, model: Optional[ModelMixin] = None, helper: Optional[Callable] = None):
        # `helper` is unused here; it is accepted for interface compatibility with CheckpointProbeBase
        self.model = model
        self.folder_path = folder_path

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal

    def get_format(self) -> str:
        return "diffusers"


class PipelineFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self.model:
            unet_conf = self.model.unet.config
        else:
            with open(self.folder_path / "unet" / "config.json", "r") as file:
                unet_conf = json.load(file)
        if unet_conf["cross_attention_dim"] == 768:
            return BaseModelType.StableDiffusion1
        elif unet_conf["cross_attention_dim"] == 1024:
            return BaseModelType.StableDiffusion2
        elif unet_conf["cross_attention_dim"] == 1280:
            return BaseModelType.StableDiffusionXLRefiner
        elif unet_conf["cross_attention_dim"] == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"Unknown base model for {self.folder_path}")

    def get_scheduler_prediction_type(self) -> SchedulerPredictionType:
        if self.model:
            scheduler_conf = self.model.scheduler.config
        else:
            with open(self.folder_path / "scheduler" / "scheduler_config.json", "r") as file:
                scheduler_conf = json.load(file)
        if scheduler_conf["prediction_type"] == "v_prediction":
            return SchedulerPredictionType.VPrediction
        elif scheduler_conf["prediction_type"] == "epsilon":
            return SchedulerPredictionType.Epsilon
        else:
            return None

    def get_variant_type(self) -> ModelVariantType:
        # This only works for pipelines! Any exception results in
        # returning the "normal" variant type.
        try:
            if self.model:
                conf = self.model.unet.config
            else:
                config_file = self.folder_path / "unet" / "config.json"
                with open(config_file, "r") as file:
                    conf = json.load(file)

            in_channels = conf["in_channels"]
            if in_channels == 9:
                return ModelVariantType.Inpaint
            elif in_channels == 5:
                return ModelVariantType.Depth
            elif in_channels == 4:
                return ModelVariantType.Normal
        except Exception:
            pass
        return ModelVariantType.Normal


class VaeFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        if self._config_looks_like_sdxl():
            return BaseModelType.StableDiffusionXL
        elif self._name_looks_like_sdxl():
            # SD and SDXL VAEs have the same shape (3-channel RGB in, 4-channel latents
            # downscaled by a factor of 8), so we can't necessarily tell them apart by
            # config hyperparameters; fall back to the model name.
            return BaseModelType.StableDiffusionXL
        else:
            return BaseModelType.StableDiffusion1

    def _config_looks_like_sdxl(self) -> bool:
        # config values that distinguish Stability's SD 1.x VAE from their SDXL VAE.
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        return config.get("scaling_factor", 0) == 0.13025 and config.get("sample_size") in [512, 1024]

    def _name_looks_like_sdxl(self) -> bool:
        return bool(re.search(r"xl\b", self._guess_name(), re.IGNORECASE))

    def _guess_name(self) -> str:
        name = self.folder_path.name
        if name == "vae":
            name = self.folder_path.parent.name
        return name


class TextualInversionFolderProbe(FolderProbeBase):
    def get_format(self) -> Optional[str]:
        return None

    def get_base_type(self) -> Optional[BaseModelType]:
        path = self.folder_path / "learned_embeds.bin"
        if not path.exists():
            return None
        checkpoint = ModelProbe._scan_and_load_checkpoint(path)
        return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()


class ONNXFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return "onnx"

    def get_base_type(self) -> BaseModelType:
        return BaseModelType.StableDiffusion1

    def get_variant_type(self) -> ModelVariantType:
        return ModelVariantType.Normal


class ControlNetFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)
        # no obvious way to distinguish between sd2-base and sd2-768
        dimension = config["cross_attention_dim"]
        base_model = {
            768: BaseModelType.StableDiffusion1,
            1024: BaseModelType.StableDiffusion2,
            2048: BaseModelType.StableDiffusionXL,
        }.get(dimension)
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
        return base_model


class LoRAFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        model_file = None
        for suffix in ["safetensors", "bin"]:
            base_file = self.folder_path / f"pytorch_lora_weights.{suffix}"
            if base_file.exists():
                model_file = base_file
                break
        if not model_file:
            raise InvalidModelException("Unknown LoRA format encountered")
        return LoRACheckpointProbe(model_file, None).get_base_type()


class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")

        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")


class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


class T2IAdapterFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        config_file = self.folder_path / "config.json"
        if not config_file.exists():
            raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
        with open(config_file, "r") as file:
            config = json.load(file)

        adapter_type = config.get("adapter_type", None)
        if adapter_type == "full_adapter_xl":
            return BaseModelType.StableDiffusionXL
        elif adapter_type in ("full_adapter", "light_adapter"):
            # I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
            return BaseModelType.StableDiffusion1
        else:
            raise InvalidModelException(
                f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
            )


############## register probe classes ##############
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)