diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 5342c97e3c..010de58533 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") - nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.") + nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.") watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.") # NODES diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 98c49fe3dc..e8d03adc8e 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -194,7 +194,7 @@ class ImagesInterface(InvocationContextInterface): board_id_ = self._data.invocation.board.board_id if self._services.configuration.nsfw_check: - image = SafetyChecker.blur_if_nsfw(image) + SafetyChecker.raise_if_nsfw(image) if self._services.configuration.watermark: image = InvisibleWatermark.add_watermark(image, "InvokeAI") diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 9fe16c42ea..226237fb55 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -4,11 +4,9 @@ wraps the safety_checker model. It respects the global "nsfw_checker" configuration variable, that allows the checker to be supressed. """ -from pathlib import Path - import numpy as np from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from PIL import Image, ImageFilter +from PIL import Image from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger @@ -20,10 +18,15 @@ repo_id = "CompVis/stable-diffusion-safety-checker" CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" +class NSFWImageException(Exception): + """Raised when a NSFW image is detected.""" + + def __init__(self): + super().__init__("A potentially NSFW image has been detected.") + + class SafetyChecker: - """ - Wrapper around SafetyChecker model. - """ + """Wrapper around SafetyChecker model.""" feature_extractor = None safety_checker = None @@ -72,22 +75,9 @@ class SafetyChecker: return has_nsfw_concept[0] @classmethod - def blur_if_nsfw(cls, image: Image.Image) -> Image.Image: + def raise_if_nsfw(cls, image: Image.Image) -> Image.Image: + """Raises an exception if the image contains NSFW content.""" if cls.has_nsfw_concept(image): - logger.warning("A potentially NSFW image has been detected. Image will be blurred.") - blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) - caution = cls._get_caution_img() - # Center the caution image on the blurred image - x = (blurry_image.width - caution.width) // 2 - y = (blurry_image.height - caution.height) // 2 - blurry_image.paste(caution, (x, y), caution) - image = blurry_image + raise NSFWImageException() return image - - @classmethod - def _get_caution_img(cls) -> Image.Image: - import invokeai.app.assets.images as image_assets - - caution = Image.open(Path(image_assets.__path__[0]) / "caution.png") - return caution.resize((caution.width // 2, caution.height // 2))