feat(nodes): raise on NSFW

This commit is contained in:
psychedelicious 2024-05-13 17:35:35 +10:00
parent 8bddcce2b0
commit 60d16fde2c
3 changed files with 14 additions and 24 deletions

View File

@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings):
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.")
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")
# NODES

View File

@ -194,7 +194,7 @@ class ImagesInterface(InvocationContextInterface):
board_id_ = self._data.invocation.board.board_id
if self._services.configuration.nsfw_check:
image = SafetyChecker.blur_if_nsfw(image)
SafetyChecker.raise_if_nsfw(image)
if self._services.configuration.watermark:
image = InvisibleWatermark.add_watermark(image, "InvokeAI")

View File

@ -4,11 +4,9 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
from pathlib import Path
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image, ImageFilter
from PIL import Image
from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
@ -20,10 +18,15 @@ repo_id = "CompVis/stable-diffusion-safety-checker"
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
class NSFWImageException(Exception):
"""Raised when a NSFW image is detected."""
def __init__(self):
super().__init__("A potentially NSFW image has been detected.")
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
"""Wrapper around SafetyChecker model."""
feature_extractor = None
safety_checker = None
@ -72,22 +75,9 @@ class SafetyChecker:
return has_nsfw_concept[0]
@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
"""Raises an exception if the image contains NSFW content."""
if cls.has_nsfw_concept(image):
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image
x = (blurry_image.width - caution.width) // 2
y = (blurry_image.height - caution.height) // 2
blurry_image.paste(caution, (x, y), caution)
image = blurry_image
raise NSFWImageException()
return image
@classmethod
def _get_caution_img(cls) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))