feat(nodes): use new blur_if_nsfw method

This commit is contained in:
psychedelicious 2024-05-13 18:33:25 +10:00
parent 9c819f0fd8
commit 93da75209c
2 changed files with 3 additions and 15 deletions

View File

@ -1,6 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from pathlib import Path
from typing import Literal, Optional from typing import Literal, Optional
import cv2 import cv2
@ -504,7 +503,7 @@ class ImageInverseLerpInvocation(BaseInvocation, WithMetadata, WithBoard):
title="Blur NSFW Image", title="Blur NSFW Image",
tags=["image", "nsfw"], tags=["image", "nsfw"],
category="image", category="image",
version="1.2.2", version="1.2.3",
) )
class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard): class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
@ -516,23 +515,12 @@ class ImageNSFWBlurInvocation(BaseInvocation, WithMetadata, WithBoard):
logger = context.logger logger = context.logger
logger.debug("Running NSFW checker") logger.debug("Running NSFW checker")
if SafetyChecker.has_nsfw_concept(image): image = SafetyChecker.blur_if_nsfw(image)
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution, (0, 0), caution)
image = blurry_image
image_dto = context.images.save(image=image) image_dto = context.images.save(image=image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)
def _get_caution_img(self) -> Image.Image:
import invokeai.app.assets.images as image_assets
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
return caution.resize((caution.width // 2, caution.height // 2))
@invocation( @invocation(
"img_watermark", "img_watermark",

View File

@ -65,7 +65,7 @@ class SafetyChecker:
@classmethod @classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image: def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
if cls.has_nsfw_concept(image): if cls.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.") logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img() caution = cls._get_caution_img()
# Center the caution image on the blurred image # Center the caution image on the blurred image