add support for safety checker (NSFW filter)

Now you can activate the Hugging Face `diffusers` library safety check
for NSFW and other potentially disturbing imagery.

To turn on the safety check, pass --safety_checker at the command
line. For developers, the flag is `safety_checker=True` passed to
ldm.generate.Generate(). Once the safety checker is turned on, it
cannot be turned off unless you reinitialize a new Generate object.

When the safety checker is active, suspect images will be blurred and
a warning icon is added. There is also a warning message printed in
the CLI, but it can be a little hard to see because of its positioning
in the output stream.

There is a slight but noticeable delay when the safety checker runs.

Note that invisible watermarking is *not* currently implemented. The
watermark code distributed by the CompViz distribution uses a library
that does not seem to be able to retrieve the watermarks it creates,
and it does not appear that Hugging Face `diffusers` or other SD
distributions are doing any watermarking.
This commit is contained in:
Lincoln Stein
2022-10-23 22:26:18 -04:00
parent b7ce5b4f1b
commit b159b2fe42
10 changed files with 195 additions and 94 deletions

View File

@ -132,20 +132,21 @@ class Generate:
def __init__(
self,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
model = None,
conf = 'configs/models.yaml',
embedding_path = None,
sampler_name = 'k_lms',
ddim_eta = 0.0, # deterministic
full_precision = False,
precision = 'auto',
gfpgan=None,
codeformer=None,
esrgan=None,
free_gpu_mem=False,
safety_checker:bool=False,
# these are deprecated; if present they override values in the conf file
weights = None,
config = None,
):
mconfig = OmegaConf.load(conf)
self.height = None
@ -176,6 +177,7 @@ class Generate:
self.free_gpu_mem = free_gpu_mem
self.size_matters = True # used to warn once about large image sizes and VRAM
self.txt2mask = None
self.safety_checker = None
# Note that in previous versions, there was an option to pass the
# device to Generate(). However the device was then ignored, so
@ -203,6 +205,19 @@ class Generate:
# gets rid of annoying messages about random seed
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
# load safety checker if requested
if safety_checker:
try:
print('>> Initializing safety checker')
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
except Exception:
print('** An error was encountered while installing the safety checker:')
print(traceback.format_exc())
def prompt2png(self, prompt, outdir, **kwargs):
"""
Takes a prompt and an output directory, writes out the requested number
@ -418,6 +433,11 @@ class Generate:
self.seed, variation_amount, with_variations
)
checker = {
'checker':self.safety_checker,
'extractor':self.safety_feature_extractor
} if self.safety_checker else None
results = generator.generate(
prompt,
iterations=iterations,
@ -428,10 +448,10 @@ class Generate:
conditioning=(uc, c),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
step_callback=step_callback, # called after each intermediate image is generated
width=width,
height=height,
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
init_image=init_image, # notice that init_image is different from init_img
mask_image=mask_image,
strength=strength,
@ -440,7 +460,8 @@ class Generate:
embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius
mask_blur_radius=mask_blur_radius,
safety_checker=checker
)
if init_color: