mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
'''
|
|
SafetyChecker class - checks images against the StabilityAI NSFW filter
|
|
and blurs images that contain potential NSFW content.
|
|
'''
|
|
import diffusers
|
|
import numpy as np
|
|
import torch
|
|
import traceback
|
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
StableDiffusionSafetyChecker,
|
|
)
|
|
from pathlib import Path
|
|
from PIL import Image, ImageFilter
|
|
from transformers import AutoFeatureExtractor
|
|
|
|
import invokeai.assets.web as web_assets
|
|
import invokeai.backend.util.logging as logger
|
|
from .globals import global_cache_dir
|
|
from .util import CPU_DEVICE
|
|
|
|
class SafetyChecker(object):
    """Run images through the Stable Diffusion NSFW classifier and blur hits.

    The classifier and its feature extractor are loaded once, from the local
    model cache only (no download is attempted). Loading is best-effort: a
    failure is logged and the instance is still constructed, but `check()`
    will then raise `RuntimeError` instead of silently skipping the filter.
    """

    CAUTION_IMG = "caution.png"

    def __init__(self, device: torch.device):
        """Load the caution overlay and the safety-checker model.

        Args:
            device: Device the classifier is moved to for the duration of each
                `check()` call; it is moved back to `CPU_DEVICE` afterwards.
        """
        self.device = device

        # Defined up front so a failed load leaves well-defined attributes
        # rather than missing ones.
        self.safety_checker = None
        self.safety_feature_extractor = None

        # The overlay is optional: `blur()` skips it when it is unavailable,
        # so a missing asset should not prevent the checker from being built.
        path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
        try:
            with Image.open(path) as caution:
                # `resize` returns a new, fully loaded image, so the source
                # file can be closed as soon as the block exits.
                self.caution_img = caution.resize(
                    (caution.width // 2, caution.height // 2)
                )
        except FileNotFoundError:
            logger.warning(
                f"Caution image {path} not found; flagged images will be blurred without it."
            )
            self.caution_img = None

        try:
            safety_model_id = "CompVis/stable-diffusion-safety-checker"
            safety_model_path = global_cache_dir("hub")
            self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                safety_model_id,
                local_files_only=True,
                cache_dir=safety_model_path,
            )
            self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
                safety_model_id,
                local_files_only=True,
                cache_dir=safety_model_path,
            )
        except Exception:
            # Deliberately broad: any load failure is reported, not raised.
            # Reset both attributes so a partial load is never used.
            self.safety_checker = None
            self.safety_feature_extractor = None
            logger.error(
                "An error was encountered while installing the safety checker:"
            )
            # Send the traceback through the logger instead of stdout so it
            # ends up wherever the rest of the application's logs go.
            logger.error(traceback.format_exc())

    def check(self, image: Image.Image) -> Image.Image:
        """Check an image against the safety checker.

        Args:
            image: The PIL image to classify.

        Returns:
            `image` itself when nothing is flagged, otherwise a blurred copy
            produced by `blur()`.

        Raises:
            RuntimeError: If the safety-checker model could not be loaded
                when this instance was constructed.
        """
        if self.safety_checker is None or self.safety_feature_extractor is None:
            raise RuntimeError(
                "The safety checker model is not loaded; see the error logged "
                "when the SafetyChecker was created."
            )

        # The checker also wants the image as a float array in NCHW layout.
        # Converting to RGB first guarantees a 3-D HWC array, so the transpose
        # below also works for grayscale or palette images. The filtered array
        # the checker returns is not used here; only the flag is.
        x_image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)

        diffusers.logging.set_verbosity_error()
        self.safety_checker.to(self.device)
        try:
            features = self.safety_feature_extractor([image], return_tensors="pt")
            features = features.to(self.device)
            _, has_nsfw_concept = self.safety_checker(
                images=x_image, clip_input=features.pixel_values
            )
        finally:
            # Offload even when extraction or inference fails, so the model
            # never stays resident on the accelerator.
            self.safety_checker.to(CPU_DEVICE)

        if not has_nsfw_concept[0]:
            return image

        logger.warning(
            "An image with potential non-safe content has been detected. A blurred image will be returned."
        )
        return self.blur(image)

    def blur(self, input: Image.Image) -> Image.Image:
        """Return a heavily blurred copy of `input` with the caution overlay.

        Args:
            input: The PIL image to obscure. The name shadows the builtin but
                is kept so existing keyword callers continue to work.

        Returns:
            A new image blurred with a Gaussian radius of 32. When the caution
            overlay is available it is pasted at the top-left corner, using
            itself as the paste mask.
        """
        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
        caution = self.caution_img
        if caution is not None:
            blurry.paste(caution, (0, 0), caution)
        return blurry
|