"""
This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""

from pathlib import Path

import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor

import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings

repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


class SafetyChecker:
    """
    Wrapper around SafetyChecker model.
    """

    feature_extractor = None
    safety_checker = None

    @classmethod
    def _load_safety_checker(cls):
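        """Load the feature extractor and checker model once, preferring the
        local cache and falling back to a HuggingFace Hub download."""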
        if cls.safety_checker is not None and cls.feature_extractor is not None:
            return

        try:
            model_path = get_config().models_path / CHECKER_PATH
            if model_path.exists():
                # A local copy exists; load it from the models directory.
                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
            else:
                # First run: download from the HuggingFace Hub, then cache the
                # weights locally so later loads need no network access.
                model_path.mkdir(parents=True, exist_ok=True)
                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(REPO_ID)
                cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(REPO_ID)
                cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
        except Exception as e:
            logger.warning(f"Could not load NSFW checker: {e}")

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
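        """Return True if the image is flagged as containing NSFW concepts.

        Returns False if the checker could not be loaded.
        """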
        cls._load_safety_checker()
        if cls.safety_checker is None or cls.feature_extractor is None:
            return False
        device = TorchDevice.choose_torch_device()
        features = cls.feature_extractor([image], return_tensors="pt")
        # BatchFeature.to() moves the tensors to the target device in place
        # and returns the same object.
        features = features.to(device)
        cls.safety_checker.to(device)
        # Build a normalized float copy of the image for the checker's `images`
        # argument; the NSFW detection itself runs on the CLIP features.
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW
        with SilenceWarnings():
            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
        return bool(has_nsfw_concept[0])

    @classmethod
    def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
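        """Return the image unchanged if it passes the NSFW check; otherwise
        return a Gaussian-blurred copy with a caution sign pasted on top."""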
        if cls.has_nsfw_concept(image):
            logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
            blurry_image = image.filter(ImageFilter.GaussianBlur(radius=32))
            caution = cls._get_caution_img()
            # Center the caution image on the blurred image
            x = (blurry_image.width - caution.width) // 2
            y = (blurry_image.height - caution.height) // 2
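            # Pass the caution image as its own paste mask so its alpha
            # channel is respected.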
            blurry_image.paste(caution, (x, y), caution)
            image = blurry_image

        return image

    @classmethod
    def _get_caution_img(cls) -> Image.Image:
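        """Load the bundled caution image, downscaled to half size."""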
        import invokeai.app.assets.images as image_assets

        caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
        return caution.resize((caution.width // 2, caution.height // 2))