Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
fix(backend): nsfw checker on non-rgb images
This commit is contained in:
commit 8bddcce2b0
parent 1ca583e407
@@ -56,16 +56,25 @@ class SafetyChecker:
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
         cls.safety_checker.to(device)
-        x_image = np.array(image).astype(np.float32) / 255.0
+
+        # Only RGB(A) images are supported, so to prevent an error when a NSFW concept is detected in an image with
+        # a different image mode, we _must_ convert it to RGB and then to a normalized, batched np array.
+        rgb_image = image.convert("RGB")
+        # Convert to normalized (0-1) np array
+        x_image = np.array(rgb_image).astype(np.float32) / 255.0
+        # Add batch dimension and transpose to NCHW
         x_image = x_image[None].transpose(0, 3, 1, 2)
+        # A warning is logged if a NSFW concept is detected - silence those, so we can handle it ourselves.
         with SilenceWarnings():
+            # `clip_input` (features) is used to check for NSFW concepts. `images` is required, but it isn't actually
+            # checked for NSFW concepts. If a NSFW concept is detected, the image is replaced with a black image.
             checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
         return has_nsfw_concept[0]
 
     @classmethod
     def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
         if cls.has_nsfw_concept(image):
-            logger.info("A potentially NSFW image has been detected. Image will be blurred.")
+            logger.warning("A potentially NSFW image has been detected. Image will be blurred.")
             blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
             caution = cls._get_caution_img()
             # Center the caution image on the blurred image
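
Why the conversion matters: the safety checker expects a 3-channel, normalized, NCHW-batched array, but PIL images in other modes (e.g. "RGBA" or grayscale "L") yield arrays with a different channel count or no channel axis at all, so the old np.array(image) path could hand the checker an input it cannot handle. The standalone sketch below (not InvokeAI code; to_nchw_batch is a hypothetical helper) shows the shapes involved and how convert("RGB") normalizes them:

    import numpy as np
    from PIL import Image

    def to_nchw_batch(image: Image.Image) -> np.ndarray:
        """Convert any PIL image to a normalized (0-1), batched NCHW float32 array."""
        rgb_image = image.convert("RGB")  # force 3 channels regardless of the input mode
        x = np.array(rgb_image).astype(np.float32) / 255.0  # HWC, values in [0, 1]
        return x[None].transpose(0, 3, 1, 2)  # add batch dim, then NHWC -> NCHW

    rgba = Image.new("RGBA", (64, 64))  # 4 channels
    gray = Image.new("L", (64, 64))     # 1 channel, no channel axis in np.array()
    print(np.array(rgba).shape)         # (64, 64, 4) - the old path would feed C=4 to the checker
    print(to_nchw_batch(rgba).shape)    # (1, 3, 64, 64)
    print(to_nchw_batch(gray).shape)    # (1, 3, 64, 64)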
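
The blur_if_nsfw branch shown above is cut off after the centering comment. As a rough illustration of that remaining step (the helper name and paste logic here are assumptions, not the project's actual code), centering a caution overlay on the blurred image can be done with a plain PIL paste:

    from PIL import Image, ImageFilter

    def blur_and_overlay(image: Image.Image, caution: Image.Image) -> Image.Image:
        # Heavy Gaussian blur, same radius as in the diff above
        blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
        # Center the caution image on the blurred image
        x = (blurry_image.width - caution.width) // 2
        y = (blurry_image.height - caution.height) // 2
        # Use the caution image's alpha channel as a paste mask if it has one
        mask = caution if caution.mode == "RGBA" else None
        blurry_image.paste(caution, (x, y), mask)
        return blurry_image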