feat(nodes): raise on NSFW

This commit is contained in:
psychedelicious 2024-05-13 17:35:35 +10:00
parent 8bddcce2b0
commit 60d16fde2c
3 changed files with 14 additions and 24 deletions

View File

@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.") nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.") watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")
# NODES # NODES

View File

@ -194,7 +194,7 @@ class ImagesInterface(InvocationContextInterface):
board_id_ = self._data.invocation.board.board_id board_id_ = self._data.invocation.board.board_id
if self._services.configuration.nsfw_check: if self._services.configuration.nsfw_check:
image = SafetyChecker.blur_if_nsfw(image) SafetyChecker.raise_if_nsfw(image)
if self._services.configuration.watermark: if self._services.configuration.watermark:
image = InvisibleWatermark.add_watermark(image, "InvokeAI") image = InvisibleWatermark.add_watermark(image, "InvokeAI")

View File

@ -4,11 +4,9 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed. configuration variable, that allows the checker to be supressed.
""" """
from pathlib import Path
import numpy as np import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image, ImageFilter from PIL import Image
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -20,10 +18,15 @@ repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
class NSFWImageException(Exception):
"""Raised when a NSFW image is detected."""
def __init__(self):
super().__init__("A potentially NSFW image has been detected.")
class SafetyChecker: class SafetyChecker:
""" """Wrapper around SafetyChecker model."""
Wrapper around SafetyChecker model.
"""
feature_extractor = None feature_extractor = None
safety_checker = None safety_checker = None
@ -72,22 +75,9 @@ class SafetyChecker:
return has_nsfw_concept[0] return has_nsfw_concept[0]
@classmethod @classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image: def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
"""Raises an exception if the image contains NSFW content."""
if cls.has_nsfw_concept(image): if cls.has_nsfw_concept(image):
logger.warning("A potentially NSFW image has been detected. Image will be blurred.") raise NSFWImageException()
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
x = (blurry_image.width - caution.width) // 2
y = (blurry_image.height - caution.height) // 2
blurry_image.paste(caution, (x, y), caution)
image = blurry_image
return image return image
@classmethod
def _get_caution_img(cls) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))