mirror of https://github.com/invoke-ai/InvokeAI
feat(nodes): raise on NSFW
parent 8bddcce2b0, commit 60d16fde2c
@@ -186,7 +186,7 @@ class InvokeAIAppConfig(BaseSettings):
     force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).")
     pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.")
     max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.")
-    nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. NSFW images will be blurred.")
+    nsfw_check: bool = Field(default=False, description="Enable NSFW checking for images. If an NSFW image is encountered during generation, execution will immediately stop. If disabled, the NSFW model is never loaded.")
     watermark: bool = Field(default=False, description="Watermark all images with `invisible-watermark`.")

     # NODES
@@ -194,7 +194,7 @@ class ImagesInterface(InvocationContextInterface):
         board_id_ = self._data.invocation.board.board_id

         if self._services.configuration.nsfw_check:
-            image = SafetyChecker.blur_if_nsfw(image)
+            SafetyChecker.raise_if_nsfw(image)

         if self._services.configuration.watermark:
             image = InvisibleWatermark.add_watermark(image, "InvokeAI")
@@ -4,11 +4,9 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
 configuration variable, that allows the checker to be supressed.
 """

-from pathlib import Path
-
 import numpy as np
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from PIL import Image, ImageFilter
+from PIL import Image
 from transformers import AutoFeatureExtractor

 import invokeai.backend.util.logging as logger
@@ -20,10 +18,15 @@ repo_id = "CompVis/stable-diffusion-safety-checker"
 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"


+class NSFWImageException(Exception):
+    """Raised when a NSFW image is detected."""
+
+    def __init__(self):
+        super().__init__("A potentially NSFW image has been detected.")
+
+
 class SafetyChecker:
-    """
-    Wrapper around SafetyChecker model.
-    """
+    """Wrapper around SafetyChecker model."""

     feature_extractor = None
     safety_checker = None
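A note on the exception's design: it takes no arguments and bakes in its message, so any error surface that records the exception type and str(e) gets a consistent report for free. A small illustration (the dict shape is illustrative, not the queue's actual error schema):

err = NSFWImageException()
record = {"error_type": type(err).__name__, "error_message": str(err)}
# {'error_type': 'NSFWImageException',
#  'error_message': 'A potentially NSFW image has been detected.'}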
@@ -72,22 +75,9 @@ class SafetyChecker:
         return has_nsfw_concept[0]

     @classmethod
-    def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
+    def raise_if_nsfw(cls, image: Image.Image) -> Image.Image:
+        """Raises an exception if the image contains NSFW content."""
         if cls.has_nsfw_concept(image):
-            logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
-            blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
-            caution = cls._get_caution_img()
-            # Center the caution image on the blurred image
-            x = (blurry_image.width - caution.width) // 2
-            y = (blurry_image.height - caution.height) // 2
-            blurry_image.paste(caution, (x, y), caution)
-            image = blurry_image
+            raise NSFWImageException()

         return image
-
-    @classmethod
-    def _get_caution_img(cls) -> Image.Image:
-        import invokeai.app.assets.images as image_assets
-
-        caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
-        return caution.resize((caution.width // 2, caution.height // 2))
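End to end, has_nsfw_concept() classifies the image and raise_if_nsfw() either returns it untouched or raises. A usage sketch, with the module path assumed from the repo layout and a blank test image standing in for real output:

from PIL import Image

# Module path assumed; the diff does not show the file name.
from invokeai.backend.image_util.safety_checker import NSFWImageException, SafetyChecker

image = Image.new("RGB", (512, 512))
try:
    image = SafetyChecker.raise_if_nsfw(image)  # returns the image unmodified
except NSFWImageException:
    print("Potentially NSFW image detected; stopping execution.")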