add support for safety checker (NSFW filter)

You can now activate the safety checker from the Hugging Face
`diffusers` library, which screens for NSFW and other potentially
disturbing imagery.

To turn on the safety checker, pass `--safety_checker` on the command
line. Developers can pass `safety_checker=True` to
`ldm.generate.Generate()`. Once the safety checker is turned on, it
cannot be turned off without creating a new `Generate` object.

When the safety checker is active, suspect images are blurred and
overlaid with a warning icon. A warning message is also printed in
the CLI, but it can be hard to spot because of where it falls in the
output stream.
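
The blur-and-overlay behavior described above can be sketched with Pillow as follows. This is a minimal illustration rather than the commit's exact code: the function name `blur_with_icon`, the in-memory stand-in icon, and the radius of 8 are assumptions chosen so the example runs without any asset files.

```python
from PIL import Image, ImageDraw, ImageFilter


def blur_with_icon(image: Image.Image, icon: Image.Image, radius: int = 8) -> Image.Image:
    """Return a blurred copy of `image` with `icon` pasted in the top-left corner."""
    blurry = image.filter(ImageFilter.GaussianBlur(radius=radius))
    # Using the RGBA icon as its own mask keeps its transparent pixels transparent.
    blurry.paste(icon, (0, 0), icon)
    return blurry


# Stand-in input: a white canvas with a sharp black square.
source = Image.new("RGB", (64, 64), "white")
ImageDraw.Draw(source).rectangle([24, 24, 40, 40], fill="black")

# Stand-in warning icon: a small, fully opaque red RGBA tile.
icon = Image.new("RGBA", (8, 8), (255, 0, 0, 255))

result = blur_with_icon(source, icon)
```

The original image is left untouched because `filter()` returns a new image; only the returned copy carries the blur and the icon.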

There is a slight but noticeable delay when the safety checker runs.

Note that invisible watermarking is *not* currently implemented. The
watermark code in the CompVis distribution uses a library that does
not seem to be able to retrieve the watermarks it creates, and it does
not appear that Hugging Face `diffusers` or other SD distributions are
doing any watermarking.
Lincoln Stein
2022-10-23 22:26:18 -04:00
parent b7ce5b4f1b
commit b159b2fe42
10 changed files with 195 additions and 94 deletions


@@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
 | `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
 | `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
+| `--safety_checker` | | `False` | Activate safety checker for NSFW and other potentially disturbing imagery |
 | `--web` | | `False` | Start in web server mode |
 | `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
 | `--port <port>` | | `9090` | Which port web server should listen for requests on. |
@@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
 | `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
 | `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
 | `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
-| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
 | `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
 | `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |


@@ -19,6 +19,7 @@ dependencies:
 # ```
 - albumentations==1.2.1
 - coloredlogs==15.0.1
+- diffusers==0.6.0
 - einops==0.4.1
 - grpcio==1.46.4
 - humanfriendly==10.0


@@ -26,6 +26,7 @@ dependencies:
 - pyreadline3
 - torch-fidelity==0.3.0
 - transformers==4.21.3
+- diffusers==0.6.0
 - torchmetrics==0.7.0
 - flask==2.1.3
 - flask_socketio==5.3.0


@@ -139,13 +139,14 @@ class Generate:
             ddim_eta = 0.0, # deterministic
             full_precision = False,
             precision = 'auto',
-            # these are deprecated; if present they override values in the conf file
-            weights = None,
-            config = None,
             gfpgan=None,
             codeformer=None,
             esrgan=None,
             free_gpu_mem=False,
+            safety_checker:bool=False,
+            # these are deprecated; if present they override values in the conf file
+            weights = None,
+            config = None,
     ):
         mconfig = OmegaConf.load(conf)
         self.height = None
@@ -176,6 +177,7 @@ class Generate:
         self.free_gpu_mem = free_gpu_mem
         self.size_matters = True # used to warn once about large image sizes and VRAM
         self.txt2mask = None
+        self.safety_checker = None
         # Note that in previous versions, there was an option to pass the
         # device to Generate(). However the device was then ignored, so
@@ -203,6 +205,19 @@ class Generate:
         # gets rid of annoying messages about random seed
         logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
+        # load safety checker if requested
+        if safety_checker:
+            try:
+                print('>> Initializing safety checker')
+                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+                from transformers import AutoFeatureExtractor
+                safety_model_id = "CompVis/stable-diffusion-safety-checker"
+                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
+            except Exception:
+                print('** An error was encountered while installing the safety checker:')
+                print(traceback.format_exc())
     def prompt2png(self, prompt, outdir, **kwargs):
         """
         Takes a prompt and an output directory, writes out the requested number
@@ -418,6 +433,11 @@ class Generate:
                 self.seed, variation_amount, with_variations
             )
+            checker = {
+                'checker':self.safety_checker,
+                'extractor':self.safety_feature_extractor
+            } if self.safety_checker else None
             results = generator.generate(
                 prompt,
                 iterations=iterations,
@@ -440,7 +460,8 @@ class Generate:
                 embiggen=embiggen,
                 embiggen_tiles=embiggen_tiles,
                 inpaint_replace=inpaint_replace,
-                mask_blur_radius=mask_blur_radius
+                mask_blur_radius=mask_blur_radius,
+                safety_checker=checker
             )
             if init_color:
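
The hunks above thread the flag from the constructor down to the generator. Below is a minimal stand-in sketch of that flow with no model loading. `GenerateSketch`, `load_checker_parts`, and `failing_loader` are hypothetical names, and defining both attributes as `None` before loading is an assumption of this sketch (it avoids an `AttributeError` if only part of the load succeeds), not something the diff does.

```python
from typing import Callable, Optional, Tuple


def load_checker_parts() -> Tuple[str, str]:
    """Hypothetical loader; stands in for the two from_pretrained() calls."""
    return "checker-model", "feature-extractor"


class GenerateSketch:
    """Stand-in for the constructor and bundling logic only."""

    def __init__(self, safety_checker: bool = False,
                 loader: Callable[[], Tuple[str, str]] = load_checker_parts):
        # Assumption of this sketch: define both attributes before loading.
        self.safety_checker: Optional[str] = None
        self.safety_feature_extractor: Optional[str] = None
        if safety_checker:
            try:
                self.safety_checker, self.safety_feature_extractor = loader()
            except Exception as err:
                # Like the diff: report the failure and continue without a checker.
                print(f"** An error was encountered while loading the safety checker: {err}")

    def checker_bundle(self) -> Optional[dict]:
        """Return the dict handed to the generator, or None when inactive."""
        if self.safety_checker is None:
            return None
        return {
            "checker": self.safety_checker,
            "extractor": self.safety_feature_extractor,
        }


def failing_loader() -> Tuple[str, str]:
    """Hypothetical loader that simulates missing model files."""
    raise OSError("model files not found in local cache")


active = GenerateSketch(safety_checker=True).checker_bundle()
inactive = GenerateSketch().checker_bundle()
failed = GenerateSketch(safety_checker=True, loader=failing_loader).checker_bundle()
```

Because the flag is read only in the constructor, switching the checker off again means building a new object, which matches the behavior described in the commit message.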


@@ -418,6 +418,11 @@ class Args(object):
             help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
             default='auto',
         )
+        model_group.add_argument(
+            '--safety_checker',
+            action='store_true',
+            help='Check for and blur potentially NSFW images',
+        )
         file_group.add_argument(
             '--from_file',
             dest='infile',
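
For reference, a `store_true` option defaults to `False` and flips to `True` when the flag is present. The sketch below uses a standalone parser rather than the project's `Args` class; `build_parser` and the group title are hypothetical.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a tiny parser that only knows the safety-checker flag."""
    parser = argparse.ArgumentParser(description="store_true flag sketch")
    model_group = parser.add_argument_group("Model selection")
    model_group.add_argument(
        "--safety_checker",
        action="store_true",
        help="Check for and blur potentially NSFW images",
    )
    return parser


parser = build_parser()
default_opt = parser.parse_args([])                    # flag absent
enabled_opt = parser.parse_args(["--safety_checker"])  # flag present
```

`argparse` does not treat `-` and `_` as interchangeable in option names, so the flag has to be spelled with an underscore on the command line.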


@@ -7,13 +7,14 @@ import numpy as np
 import random
 import os
 from tqdm import tqdm, trange
-from PIL import Image
+from PIL import Image, ImageFilter
 from einops import rearrange, repeat
 from pytorch_lightning import seed_everything
 from ldm.invoke.devices import choose_autocast
 from ldm.util import rand_perlin_2d
 downsampling = 8
+CAUTION_IMG = 'assets/caution.png'
 class Generator():
     def __init__(self, model, precision):
@@ -22,6 +23,7 @@ class Generator():
         self.seed = None
         self.latent_channels = model.channels
         self.downsampling_factor = downsampling # BUG: should come from model or config
+        self.safety_checker = None
         self.perlin = 0.0
         self.threshold = 0
         self.variation_amount = 0
@@ -42,8 +44,10 @@ class Generator():
     def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
+                 safety_checker:dict=None,
                  **kwargs):
         scope = choose_autocast(self.precision)
+        self.safety_checker = safety_checker
         make_image = self.get_make_image(
             prompt,
             init_image = init_image,
@@ -77,10 +81,17 @@ class Generator():
                         pass
                 image = make_image(x_T)
+                if self.safety_checker is not None:
+                    image = self.safety_check(image)
                 results.append([image, seed])
                 if image_callback is not None:
                     image_callback(image, seed, first_seed=first_seed)
                 seed = self.new_seed()
         return results
     def sample_to_image(self,samples):
@@ -169,6 +180,39 @@ class Generator():
         return v2
+    def safety_check(self,image:Image.Image):
+        '''
+        If the CompVis safety checker flags an NSFW image, we
+        blur it out.
+        '''
+        import diffusers
+        checker = self.safety_checker['checker']
+        extractor = self.safety_checker['extractor']
+        features = extractor([image], return_tensors="pt")
+        # unfortunately checker requires the numpy version, so we have to convert back
+        x_image = np.array(image).astype(np.float32) / 255.0
+        x_image = x_image[None].transpose(0, 3, 1, 2)
+        diffusers.logging.set_verbosity_error()
+        checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
+        if has_nsfw_concept[0]:
+            print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
+            return self.blur(image)
+        else:
+            return image
+    def blur(self,input):
+        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
+        try:
+            caution = Image.open(CAUTION_IMG)
+            caution = caution.resize((caution.width // 2, caution.height //2))
+            blurry.paste(caution,(0,0),caution)
+        except FileNotFoundError:
+            pass
+        return blurry
     # this is a handy routine for debugging use. Given a generated sample,
     # convert it into a PNG image and store it at the indicated path
     def save_sample(self, sample, filepath):
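
The array conversion in `safety_check` turns a height, width, channel image into a batched, channel-first float array. A small sketch on a synthetic array shows the resulting shape and value range; no model is involved, and `to_batched_chw` is a hypothetical helper name.

```python
import numpy as np


def to_batched_chw(pixels: np.ndarray) -> np.ndarray:
    """Convert an (H, W, C) uint8 array to a (1, C, H, W) float32 array in [0, 1]."""
    scaled = pixels.astype(np.float32) / 255.0
    # [None] adds a leading batch axis; transpose moves channels ahead of H and W.
    return scaled[None].transpose(0, 3, 1, 2)


# Synthetic 4x6 RGB image with a single fully saturated red pixel.
pixels = np.zeros((4, 6, 3), dtype=np.uint8)
pixels[1, 2, 0] = 255

batched = to_batched_chw(pixels)
```

The red pixel at row 1, column 2 ends up at `batched[0, 0, 1, 2]`: batch 0, channel 0, then the original row and column.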


@@ -1,5 +1,6 @@
 albumentations==0.4.3
 einops==0.3.0
+diffusers==0.6.0
 huggingface-hub==0.8.1
 imageio==2.9.0
 imageio-ffmpeg==0.4.2


@@ -32,6 +32,7 @@ send2trash
 dependency_injector==4.40.0
 eventlet
 realesrgan
+diffusers
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
 git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan


@@ -79,6 +79,7 @@ def main():
             codeformer=codeformer,
             esrgan=esrgan,
             free_gpu_mem=opt.free_gpu_mem,
+            safety_checker=opt.safety_checker,
         )
     except (FileNotFoundError, IOError, KeyError) as e:
         print(f'{e}. Aborting.')


@@ -5,7 +5,7 @@
 # two machines must share a common .cache directory.
 from transformers import CLIPTokenizer, CLIPTextModel
 import clip
-from transformers import BertTokenizerFast
+from transformers import BertTokenizerFast, AutoFeatureExtractor
 import sys
 import transformers
 import os
@@ -17,21 +17,27 @@ import traceback
 transformers.logging.set_verbosity_error()
+#---------------------------------------------
 # this will preload the Bert tokenizer files
-print('Loading bert tokenizer (ignore deprecation errors)...', end='')
+def download_bert():
+    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
         tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
     print('...success')
     sys.stdout.flush()
+#---------------------------------------------
 # this will download requirements for Kornia
-print('Loading Kornia requirements...', end='')
+def download_kornia():
+    print('Installing Kornia requirements...', end='')
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
         import kornia
     print('...success')
+#---------------------------------------------
+def download_clip():
     version = 'openai/clip-vit-large-patch14'
     sys.stdout.flush()
     print('Loading CLIP model...',end='')
@@ -39,19 +45,11 @@ tokenizer = CLIPTokenizer.from_pretrained(version)
     transformer = CLIPTextModel.from_pretrained(version)
     print('...success')
-# In the event that the user has installed GFPGAN and also elected to use
-# RealESRGAN, this will attempt to download the model needed by RealESRGANer
-gfpgan = False
+#---------------------------------------------
+def download_gfpgan():
+    print('Installing models from RealESRGAN and facexlib...',end='')
     try:
         from realesrgan import RealESRGANer
-    gfpgan = True
-except ModuleNotFoundError:
-    pass
-if gfpgan:
-    print('Loading models from RealESRGAN and facexlib...',end='')
-    try:
         from realesrgan.archs.srvgg_arch import SRVGGNetCompact
         from facexlib.utils.face_restoration_helper import FaceRestoreHelper
@@ -94,7 +92,9 @@ if gfpgan:
         print('Error loading GFPGAN:')
         print(traceback.format_exc())
-print('preloading CodeFormer model file...',end='')
+#---------------------------------------------
+def download_codeformer():
+    print('Installing CodeFormer model file...',end='')
     try:
         model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
         model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
@@ -107,7 +107,9 @@ except Exception:
         print(traceback.format_exc())
     print('...success')
-print('Loading clipseg model for text-based masking...',end='')
+#---------------------------------------------
+def download_clipseg():
+    print('Installing clipseg model for text-based masking...',end='')
     try:
         model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
         model_dest = 'src/clipseg/clipseg_weights.zip'
@@ -134,4 +136,28 @@ except Exception:
         print(traceback.format_exc())
     print('...success')
+#-------------------------------------
+def download_safety_checker():
+    print('Installing safety model for NSFW content detection...',end='')
+    try:
+        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+    except ModuleNotFoundError:
+        print('Error installing safety checker model:')
+        print(traceback.format_exc())
+        return
+    safety_model_id = "CompVis/stable-diffusion-safety-checker"
+    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
+    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
+    print('...success')
+#-------------------------------------
+if __name__ == '__main__':
+    download_bert()
+    download_kornia()
+    download_clip()
+    download_gfpgan()
+    download_codeformer()
+    download_clipseg()
+    download_safety_checker()