"""
This module defines a class, "SafetyChecker", that wraps the
StableDiffusionSafetyChecker model and is used as a process-wide singleton
via its class methods. It respects the global "nsfw_checker" configuration
variable, which allows the checker to be suppressed.
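
A minimal usage sketch (the import path below is an assumption; adjust it
to wherever this module lives in the package):

    from PIL import Image

    from invokeai.backend.image_util.safety_checker import SafetyChecker

    image = Image.open("generated.png")
    if SafetyChecker.has_nsfw_concept(image):
        # The caller decides how to handle a flagged image, e.g. blur or discard it.
        ...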
"""
import numpy as np
from PIL import Image

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device

config = InvokeAIAppConfig.get_config()

# Location of the safety checker model, relative to the InvokeAI models directory (config.models_path).
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


class SafetyChecker:
    """
    Wrapper around the StableDiffusionSafetyChecker model.

    The model and its feature extractor are loaded lazily on first use and
    cached at class level, so the class acts as a singleton via its class
    methods.
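
    Example (illustrative):

        if SafetyChecker.safety_checker_available():
            flagged = SafetyChecker.has_nsfw_concept(image)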
    """

    # Lazily-loaded model and feature extractor, cached at class level.
    safety_checker = None
    feature_extractor = None
    tried_load: bool = False

    @classmethod
    def _load_safety_checker(cls):
        # Attempt the load only once; success or failure is cached at class level.
        if cls.tried_load:
            return

        if config.nsfw_checker:
            try:
                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                from transformers import AutoFeatureExtractor

                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
                logger.info("NSFW checker initialized")
            except Exception as e:
                logger.warning(f"Could not load NSFW checker: {str(e)}")
        else:
            logger.info("NSFW checker loading disabled")
        cls.tried_load = True

    @classmethod
    def safety_checker_available(cls) -> bool:
        cls._load_safety_checker()
        return cls.safety_checker is not None

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
        # Treat the image as safe if the checker is disabled or failed to load.
        if not cls.safety_checker_available():
            return False

        device = choose_torch_device()
        # Preprocess the image for the checker's CLIP image encoder;
        # BatchFeature.to() moves its tensors to the device in place.
        features = cls.feature_extractor([image], return_tensors="pt")
        features.to(device)
        cls.safety_checker.to(device)
        # Convert the PIL image to the NCHW float array in [0, 1] expected by the checker.
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)
        # Silence library warnings emitted during the safety check.
        with SilenceWarnings():
            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
        # The checker returns one flag per input image; a single image was passed in.
        return has_nsfw_concept[0]