InvokeAI/invokeai/backend/safety_checker.py

78 lines
2.8 KiB
Python
Raw Normal View History

'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
2023-04-29 13:43:40 +00:00
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
2023-03-11 21:16:44 +00:00
from .util import CPU_DEVICE
# Application configuration, fetched once at import time; SafetyChecker reads
# `config.models_path` from it to locate the safety-checker model files.
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
    """Screen images with the Stable Diffusion safety checker and blur flagged ones.

    The checker model and its feature extractor are loaded once, at
    construction time, from directories under ``config.models_path``.
    ``check()`` moves the model to the requested device for the duration of
    a call and moves it back to ``CPU_DEVICE`` afterwards.
    """

    # File name of the caution overlay, looked up inside the web assets package.
    CAUTION_IMG = "caution.png"

    def __init__(self, device: torch.device):
        """Load the caution overlay and the safety-checker models.

        Args:
            device: Device the checker model is moved to while ``check()``
                is running.

        A failure to load the models is logged rather than raised.
        NOTE(review): in that case ``safety_checker`` and
        ``safety_feature_extractor`` are never assigned, so a later
        ``check()`` call fails with ``AttributeError``.
        """
        path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
        # resize() returns an independent in-memory image, so the source
        # file can be closed as soon as the half-size overlay has been built.
        with Image.open(path) as caution:
            self.caution_img = caution.resize(
                (caution.width // 2, caution.height // 2)
            )
        self.device = device

        try:
            safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
            feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
            self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
        except Exception:
            logger.error(
                "An error was encountered while installing the safety checker:"
            )
            # Route the traceback through the logger, like every other
            # message in this module, instead of printing it to stdout.
            logger.error(traceback.format_exc())

    def check(self, image: Image.Image) -> Image.Image:
        """Check an image against the safety checker.

        Args:
            image: The PIL image to screen.

        Returns:
            ``image`` itself when nothing is flagged, otherwise a blurred
            copy produced by ``blur()``.
        """
        self.safety_checker.to(self.device)
        try:
            features = self.safety_feature_extractor([image], return_tensors="pt")
            # Keep the returned object so this is correct whether or not
            # `.to()` moves the tensors in place.
            features = features.to(self.device)

            # unfortunately checker requires the numpy version, so we have to convert back.
            # Converting to RGB first guarantees a (H, W, C) array; a
            # single-channel image would otherwise yield a 3-axis array and
            # make the 4-axis transpose below raise ValueError.
            x_image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
            x_image = x_image[None].transpose(0, 3, 1, 2)

            diffusers.logging.set_verbosity_error()
            # Only the per-image flag is used; the checker's returned image
            # array is discarded.
            _checked_image, has_nsfw_concept = self.safety_checker(
                images=x_image, clip_input=features.pixel_values
            )
        finally:
            # Offload even when extraction or inference raises, so the model
            # is never left on the target device.
            self.safety_checker.to(CPU_DEVICE)

        if has_nsfw_concept[0]:
            logger.warning(
                "An image with potential non-safe content has been detected. A blurred image will be returned."
            )
            return self.blur(image)
        return image

    def blur(self, input):
        """Return a Gaussian-blurred copy of ``input`` with the caution overlay.

        Args:
            input: The PIL image to blur. The parameter name shadows the
                builtin but is kept so keyword callers keep working.

        Returns:
            A new blurred image; ``input`` is not modified.
        """
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        # The overlay was fully loaded into memory in __init__, so pasting
        # it touches no files and needs no FileNotFoundError handling.
        if caution := self.caution_img:
            # The overlay doubles as its own paste mask.
            blurry.paste(caution, (0, 0), caution)
        return blurry