"""
This module defines a singleton object, "safety_checker" that wraps the
safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be suppressed.
"""
import numpy as np
from PIL import Image

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device

config = InvokeAIAppConfig.get_config()

# Location of the checker weights/config, relative to ``config.models_path``.
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


class SafetyChecker:
    """
    Wrapper around SafetyChecker model.

    All state lives on the class and every method is a classmethod, so the
    class itself acts as the singleton; it is never instantiated here.
    """

    # Loaded model and its matching feature extractor. Both stay None until a
    # load succeeds, and are only ever set together.
    safety_checker = None
    feature_extractor = None
    # True once a load has been attempted (successfully or not), so a failed
    # or disabled load is not retried on every call.
    tried_load: bool = False

    @classmethod
    def _load_safety_checker(cls):
        """
        Load the checker model and feature extractor at most once.

        Does nothing if a load was already attempted. When the
        ``nsfw_checker`` config flag is off, nothing is loaded. Any exception
        raised while importing or loading is logged as a warning and
        swallowed, leaving the checker unavailable.
        """
        if cls.tried_load:
            return

        if config.nsfw_checker:
            try:
                # Imported lazily so these packages are only touched when the
                # checker is actually enabled.
                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                from transformers import AutoFeatureExtractor

                checker_dir = config.models_path / CHECKER_PATH
                # Load into locals first: if the second load fails we must not
                # leave `safety_checker` set with `feature_extractor` still
                # None, or `safety_checker_available()` would report True and
                # `has_nsfw_concept()` would crash calling None.
                loaded_checker = StableDiffusionSafetyChecker.from_pretrained(checker_dir)
                loaded_extractor = AutoFeatureExtractor.from_pretrained(checker_dir)
                cls.safety_checker = loaded_checker
                cls.feature_extractor = loaded_extractor
                logger.info("NSFW checker initialized")
            except Exception as e:
                logger.warning(f"Could not load NSFW checker: {str(e)}")
        else:
            logger.info("NSFW checker loading disabled")

        cls.tried_load = True

    @classmethod
    def safety_checker_available(cls) -> bool:
        """
        Return True if the checker model is loaded and usable.

        Triggers the one-time load attempt if it has not happened yet.
        """
        cls._load_safety_checker()
        return cls.safety_checker is not None

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
        """
        Run the safety checker on a single image.

        :param image: PIL image to check. The array built from it is
            transposed as a 4-axis batch, so the image must convert to a
            3-axis (height, width, channels) array -- presumably RGB; confirm
            against callers.
        :return: The checker's verdict for the image, or False when the
            checker is unavailable (disabled or failed to load).
        """
        if not cls.safety_checker_available():
            return False

        device = choose_torch_device()
        features = cls.feature_extractor([image], return_tensors="pt")
        # Return value deliberately not rebound; this relies on `.to()` moving
        # the batch in place (see transformers BatchFeature.to -- confirm if
        # the extractor's return type ever changes).
        features.to(device)
        cls.safety_checker.to(device)

        # Scale pixel values to [0, 1], add a leading batch axis, then move
        # the channel axis ahead of height/width: (1, H, W, C) -> (1, C, H, W).
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)

        with SilenceWarnings():
            # The first return value (the checked images) is not needed here.
            _, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
        # One image in, so the single verdict is at index 0.
        return has_nsfw_concept[0]