mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for safety checker (NSFW filter)
Now you can activate the Hugging Face `diffusers` library safety check for NSFW and other potentially disturbing imagery. To turn on the safety check, pass --safety_checker at the command line. For developers, the flag is `safety_checker=True` passed to ldm.generate.Generate(). Once the safety checker is turned on, it cannot be turned off unless you reinitialize a new Generate object. When the safety checker is active, suspect images will be blurred and a warning icon is added. There is also a warning message printed in the CLI, but it can be a little hard to see because of its positioning in the output stream. There is a slight but noticeable delay when the safety checker runs. Note that invisible watermarking is *not* currently implemented. The watermark code distributed by the CompViz distribution uses a library that does not seem to be able to retrieve the watermarks it creates, and it does not appear that Hugging Face `diffusers` or other SD distributions are doing any watermarking.
This commit is contained in:
@ -69,16 +69,17 @@ def main():
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
conf = opt.conf,
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
full_precision = opt.full_precision,
|
||||
precision = opt.precision,
|
||||
precision = opt.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=opt.free_gpu_mem,
|
||||
safety_checker=opt.safety_checker,
|
||||
)
|
||||
except (FileNotFoundError, IOError, KeyError) as e:
|
||||
print(f'{e}. Aborting.')
|
||||
|
@ -5,7 +5,7 @@
|
||||
# two machines must share a common .cache directory.
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import clip
|
||||
from transformers import BertTokenizerFast
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
import sys
|
||||
import transformers
|
||||
import os
|
||||
@ -17,41 +17,39 @@ import traceback
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
#---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
sys.stdout.flush()
|
||||
def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
sys.stdout.flush()
|
||||
|
||||
#---------------------------------------------
|
||||
# this will download requirements for Kornia
|
||||
print('Loading Kornia requirements...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
import kornia
|
||||
print('...success')
|
||||
def download_kornia():
|
||||
print('Installing Kornia requirements...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
import kornia
|
||||
print('...success')
|
||||
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model...',end='')
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model...',end='')
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
|
||||
# In the event that the user has installed GFPGAN and also elected to use
|
||||
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
|
||||
gfpgan = False
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
gfpgan = True
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
if gfpgan:
|
||||
print('Loading models from RealESRGAN and facexlib...',end='')
|
||||
#---------------------------------------------
|
||||
def download_gfpgan():
|
||||
print('Installing models from RealESRGAN and facexlib...',end='')
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
@ -94,44 +92,72 @@ if gfpgan:
|
||||
print('Error loading GFPGAN:')
|
||||
print(traceback.format_exc())
|
||||
|
||||
print('preloading CodeFormer model file...',end='')
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
#---------------------------------------------
|
||||
def download_codeformer():
|
||||
print('Installing CodeFormer model file...',end='')
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clipseg():
|
||||
print('Installing clipseg model for text-based masking...',end='')
|
||||
try:
|
||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||
weights_dir = 'src/clipseg/weights'
|
||||
if not os.path.exists(weights_dir):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||
zip.extractall('src/clipseg')
|
||||
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||
os.remove(model_dest)
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
'src/clipseg/weights/rd64-uni-refined.pth',
|
||||
map_location=torch.device('cpu')
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
except Exception:
|
||||
print('Error installing clipseg model:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
||||
print('Loading clipseg model for text-based masking...',end='')
|
||||
try:
|
||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||
weights_dir = 'src/clipseg/weights'
|
||||
if not os.path.exists(weights_dir):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_url,model_dest)
|
||||
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||
zip.extractall('src/clipseg')
|
||||
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||
os.remove(model_dest)
|
||||
from clipseg_models.clipseg import CLIPDensePredT
|
||||
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||
model.eval()
|
||||
model.load_state_dict(
|
||||
torch.load(
|
||||
'src/clipseg/weights/rd64-uni-refined.pth',
|
||||
map_location=torch.device('cpu')
|
||||
),
|
||||
strict=False,
|
||||
)
|
||||
except Exception:
|
||||
print('Error installing clipseg model:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
#-------------------------------------
|
||||
def download_safety_checker():
|
||||
print('Installing safety model for NSFW content detection...',end='')
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
except ModuleNotFoundError:
|
||||
print('Error installing safety checker model:')
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
print('...success')
|
||||
|
||||
#-------------------------------------
|
||||
if __name__ == '__main__':
|
||||
download_bert()
|
||||
download_kornia()
|
||||
download_clip()
|
||||
download_gfpgan()
|
||||
download_codeformer()
|
||||
download_clipseg()
|
||||
download_safety_checker()
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user