mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): raise on NSFW
This commit is contained in:
parent
8bddcce2b0
commit
60d16fde2c
@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings):
|
|||||||
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
|
||||||
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
|
||||||
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
|
||||||
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.")
|
nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
|
||||||
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")
|
watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")
|
||||||
|
|
||||||
# NODES
|
# NODES
|
||||||
|
@ -194,7 +194,7 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
board_id_ = self._data.invocation.board.board_id
|
board_id_ = self._data.invocation.board.board_id
|
||||||
|
|
||||||
if self._services.configuration.nsfw_check:
|
if self._services.configuration.nsfw_check:
|
||||||
image = SafetyChecker.blur_if_nsfw(image)
|
SafetyChecker.raise_if_nsfw(image)
|
||||||
|
|
||||||
if self._services.configuration.watermark:
|
if self._services.configuration.watermark:
|
||||||
image = InvisibleWatermark.add_watermark(image, "InvokeAI")
|
image = InvisibleWatermark.add_watermark(image, "InvokeAI")
|
||||||
|
@ -4,11 +4,9 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
|
|||||||
configuration variable, that allows the checker to be supressed.
|
configuration variable, that allows the checker to be supressed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from PIL import Image, ImageFilter
|
from PIL import Image
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
@ -20,10 +18,15 @@ repo_id = "CompVis/stable-diffusion-safety-checker"
|
|||||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||||
|
|
||||||
|
|
||||||
|
class NSFWImageException(Exception):
|
||||||
|
"""Raised when a NSFW image is detected."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("A potentially NSFW image has been detected.")
|
||||||
|
|
||||||
|
|
||||||
class SafetyChecker:
|
class SafetyChecker:
|
||||||
"""
|
"""Wrapper around SafetyChecker model."""
|
||||||
Wrapper around SafetyChecker model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
safety_checker = None
|
safety_checker = None
|
||||||
@ -72,22 +75,9 @@ class SafetyChecker:
|
|||||||
return has_nsfw_concept[0]
|
return has_nsfw_concept[0]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
|
def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
|
||||||
|
"""Raises an exception if the image contains NSFW content."""
|
||||||
if cls.has_nsfw_concept(image):
|
if cls.has_nsfw_concept(image):
|
||||||
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
|
raise NSFWImageException()
|
||||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
||||||
caution = cls._get_caution_img()
|
|
||||||
# Center the caution image on the blurred image
|
|
||||||
x = (blurry_image.width - caution.width) // 2
|
|
||||||
y = (blurry_image.height - caution.height) // 2
|
|
||||||
blurry_image.paste(caution, (x, y), caution)
|
|
||||||
image = blurry_image
|
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _get_caution_img(cls) -> Image.Image:
|
|
||||||
import invokeai.app.assets.images as image_assets
|
|
||||||
|
|
||||||
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
|
||||||
return caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user