InvokeAI/invokeai/backend/image_util/safety_checker.py
2023-10-18 09:08:13 +11:00

66 lines
2.3 KiB
Python

"""
This module defines the SafetyChecker class, which wraps the Stable Diffusion
safety_checker model using class-level state, so it behaves like a singleton.
It respects the global "nsfw_checker" configuration variable, which allows the
checker to be suppressed.
"""
import numpy as np
from PIL import Image
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend import SilenceWarnings
from invokeai.backend.util.devices import choose_torch_device
# Location of the safety-checker weights, relative to config.models_path.
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
# Application-wide configuration; its "nsfw_checker" flag decides whether
# SafetyChecker attempts to load the model at all.
config = InvokeAIAppConfig.get_config()
class SafetyChecker:
    """
    Wrapper around the StableDiffusionSafetyChecker model.

    All state lives on the class itself, so every caller shares one lazily
    loaded model/feature-extractor pair. Loading is attempted at most once.
    If it is disabled via the "nsfw_checker" config option or fails, the
    checker stays unavailable and has_nsfw_concept() reports False.
    """

    # Loaded StableDiffusionSafetyChecker model, or None if unavailable.
    safety_checker = None
    # Feature extractor that builds the checker's CLIP input, or None if unavailable.
    feature_extractor = None
    # Set after the first load attempt so a disabled/failed load is not retried per call.
    tried_load: bool = False

    @classmethod
    def _load_safety_checker(cls):
        """
        Load the safety checker model and its feature extractor (once).

        Both objects are published on the class together, and only if both
        loaded successfully. A partial load therefore never leaves a model
        without its feature extractor. Load errors are logged, not raised.
        """
        if cls.tried_load:
            return

        if config.nsfw_checker:
            try:
                # Imported lazily so these heavy dependencies are only pulled in when enabled.
                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                from transformers import AutoFeatureExtractor

                checker_path = config.models_path / CHECKER_PATH
                safety_checker = StableDiffusionSafetyChecker.from_pretrained(checker_path)
                feature_extractor = AutoFeatureExtractor.from_pretrained(checker_path)
            except Exception as e:
                logger.warning(f"Could not load NSFW checker: {e}")
            else:
                cls.safety_checker = safety_checker
                cls.feature_extractor = feature_extractor
                logger.info("NSFW checker initialized")
        else:
            logger.info("NSFW checker loading disabled")
        cls.tried_load = True

    @classmethod
    def safety_checker_available(cls) -> bool:
        """Return True if the model and feature extractor are both loaded (loading them on first call)."""
        cls._load_safety_checker()
        return cls.safety_checker is not None and cls.feature_extractor is not None

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
        """
        Return True if the safety checker flags ``image`` as NSFW.

        :param image: PIL image to check. Any mode is accepted; non-RGB
            images are converted to RGB first.
        :return: False when the checker is disabled or could not be loaded,
            otherwise the checker's verdict for this single image.
        """
        if not cls.safety_checker_available():
            return False

        # The pixel array below must have exactly three channels for the
        # (1, H, W, C) -> (1, C, H, W) transpose. Grayscale ("L") images would
        # otherwise produce a 2-D array and make the transpose raise.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")

        device = choose_torch_device()
        features = cls.feature_extractor([rgb_image], return_tensors="pt")
        features = features.to(device)
        cls.safety_checker.to(device)

        # Scale 8-bit RGB values to [0, 1] and add a batch axis in channel-first order.
        x_image = np.array(rgb_image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)

        with SilenceWarnings():
            _, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
        return bool(has_nsfw_concept[0])