fix(backend): nsfw checker on non-rgb images

This commit is contained in:
psychedelicious 2024-04-24 08:19:08 +10:00
parent 1ca583e407
commit 8bddcce2b0

View File

@ -56,16 +56,25 @@ class SafetyChecker:
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
cls.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
# Only RGB(A) images are supported, so to prevent an error when a NSFW concept is detect in an image with
# a different image mode, we _must_ convert it to RGB and then to a normalized, batched np array.
rgb_image = image.convert("RGB")
# Convert to normalized (0-1) np array
x_image = np.array(rgb_image).astype(np.float32) / 255.0
# Add batch dimension and transpose to NCHW
x_image = x_image[None].transpose(0, 3, 1, 2)
# A warning is logged if a NSFW concept is detected - silence those, so we can handle it ourselves.
with SilenceWarnings():
# `clip_input` (features) is used to check for NSFW concepts. `images` is required, but it isn't actually
# checked for NSFW concepts. If a NSFW concept is detected, the the image is replaced with a black image.
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]
@classmethod
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
if cls.has_nsfw_concept(image):
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = cls._get_caution_img()
# Center the caution image on the blurred image