import json
import os
from enum import Enum
from typing import Literal, Optional

from omegaconf import OmegaConf
from pydantic import Field

from .base import (
    BaseModelType,
    DiffusersModel,
    InvalidModelException,
    ModelConfigBase,
    ModelType,
    ModelVariantType,
    classproperty,
    read_checkpoint_meta,
)


class StableDiffusionXLModelFormat(str, Enum):
    Checkpoint = "checkpoint"
    Diffusers = "diffusers"


class StableDiffusionXLModel(DiffusersModel):
    # TODO: check that configs overwriten properly
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusionXL,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == StableDiffusionXLModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get("state_dict", checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusionXLModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.loads(f.read())
                in_channels = unet_config["in_channels"]

            else:
                raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        if ckpt_config_path is None:
            # avoid circular import
            from .stable_diffusion import _select_ckpt_config

            ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return StableDiffusionXLModelFormat.Diffusers
        else:
            return StableDiffusionXLModelFormat.Checkpoint

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        # The convert script adapted from the diffusers package uses
        # strings for the base model type. To avoid making too many
        # source code changes, we simply translate here
        if isinstance(config, cls.CheckpointConfig):
            from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache

            return _convert_ckpt_and_cache(
                version=base_model,
                model_config=config,
                output_path=output_path,
                use_safetensors=False,  # corrupts sdxl models for some reason
            )
        else:
            return model_path