mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for safety checker (NSFW filter)
Now you can activate the Hugging Face `diffusers` library safety check for NSFW and other potentially disturbing imagery. To turn on the safety check, pass --safety_checker at the command line. For developers, the flag is `safety_checker=True` passed to ldm.generate.Generate(). Once the safety checker is turned on, it cannot be turned off unless you reinitialize a new Generate object. When the safety checker is active, suspect images will be blurred and a warning icon is added. There is also a warning message printed in the CLI, but it can be a little hard to see because of its positioning in the output stream. There is a slight but noticeable delay when the safety checker runs. Note that invisible watermarking is *not* currently implemented. The watermark code distributed by the CompViz distribution uses a library that does not seem to be able to retrieve the watermarks it creates, and it does not appear that Hugging Face `diffusers` or other SD distributions are doing any watermarking.
This commit is contained in:
@ -132,20 +132,21 @@ class Generate:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
full_precision = False,
|
||||
precision = 'auto',
|
||||
# these are deprecated; if present they override values in the conf file
|
||||
weights = None,
|
||||
config = None,
|
||||
model = None,
|
||||
conf = 'configs/models.yaml',
|
||||
embedding_path = None,
|
||||
sampler_name = 'k_lms',
|
||||
ddim_eta = 0.0, # deterministic
|
||||
full_precision = False,
|
||||
precision = 'auto',
|
||||
gfpgan=None,
|
||||
codeformer=None,
|
||||
esrgan=None,
|
||||
free_gpu_mem=False,
|
||||
safety_checker:bool=False,
|
||||
# these are deprecated; if present they override values in the conf file
|
||||
weights = None,
|
||||
config = None,
|
||||
):
|
||||
mconfig = OmegaConf.load(conf)
|
||||
self.height = None
|
||||
@ -176,6 +177,7 @@ class Generate:
|
||||
self.free_gpu_mem = free_gpu_mem
|
||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||
self.txt2mask = None
|
||||
self.safety_checker = None
|
||||
|
||||
# Note that in previous versions, there was an option to pass the
|
||||
# device to Generate(). However the device was then ignored, so
|
||||
@ -203,6 +205,19 @@ class Generate:
|
||||
# gets rid of annoying messages about random seed
|
||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
try:
|
||||
print('>> Initializing safety checker')
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
|
||||
except Exception:
|
||||
print('** An error was encountered while installing the safety checker:')
|
||||
print(traceback.format_exc())
|
||||
|
||||
def prompt2png(self, prompt, outdir, **kwargs):
|
||||
"""
|
||||
Takes a prompt and an output directory, writes out the requested number
|
||||
@ -418,6 +433,11 @@ class Generate:
|
||||
self.seed, variation_amount, with_variations
|
||||
)
|
||||
|
||||
checker = {
|
||||
'checker':self.safety_checker,
|
||||
'extractor':self.safety_feature_extractor
|
||||
} if self.safety_checker else None
|
||||
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
@ -428,10 +448,10 @@ class Generate:
|
||||
conditioning=(uc, c),
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=image_callback, # called after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
width=width,
|
||||
height=height,
|
||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||
init_image=init_image, # notice that init_image is different from init_img
|
||||
mask_image=mask_image,
|
||||
strength=strength,
|
||||
@ -440,7 +460,8 @@ class Generate:
|
||||
embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
inpaint_replace=inpaint_replace,
|
||||
mask_blur_radius=mask_blur_radius
|
||||
mask_blur_radius=mask_blur_radius,
|
||||
safety_checker=checker
|
||||
)
|
||||
|
||||
if init_color:
|
||||
|
Reference in New Issue
Block a user