mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for safety checker (NSFW filter)
Now you can activate the Hugging Face `diffusers` library safety check for NSFW and other potentially disturbing imagery. To turn on the safety check, pass --safety_checker at the command line. For developers, the flag is `safety_checker=True` passed to ldm.generate.Generate(). Once the safety checker is turned on, it cannot be turned off unless you reinitialize a new Generate object. When the safety checker is active, suspect images will be blurred and a warning icon is added. There is also a warning message printed in the CLI, but it can be a little hard to see because of its positioning in the output stream. There is a slight but noticeable delay when the safety checker runs. Note that invisible watermarking is *not* currently implemented. The watermark code distributed by the CompViz distribution uses a library that does not seem to be able to retrieve the watermarks it creates, and it does not appear that Hugging Face `diffusers` or other SD distributions are doing any watermarking.
This commit is contained in:
@ -86,6 +86,7 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
|||||||
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
|
| `--model <modelname>` | | `stable-diffusion-1.4` | Loads model specified in configs/models.yaml. Currently one of "stable-diffusion-1.4" or "laion400m" |
|
||||||
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
|
| `--full_precision` | `-F` | `False` | Run in slower full-precision mode. Needed for Macintosh M1/M2 hardware and some older video cards. |
|
||||||
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
| `--png_compression <0-9>` | `-z<0-9>` | 6 | Select level of compression for output files, from 0 (no compression) to 9 (max compression) |
|
||||||
|
| `--safety-checker` | | False | Activate safety checker for NSFW and other potentially disturbing imagery |
|
||||||
| `--web` | | `False` | Start in web server mode |
|
| `--web` | | `False` | Start in web server mode |
|
||||||
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
|
| `--host <ip addr>` | | `localhost` | Which network interface web server should listen on. Set to 0.0.0.0 to listen on any. |
|
||||||
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
|
| `--port <port>` | | `9090` | Which port web server should listen for requests on. |
|
||||||
@ -97,7 +98,6 @@ overridden on a per-prompt basis (see [List of prompt arguments](#list-of-prompt
|
|||||||
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
|
| `--embedding_path <path>` | | `None` | Path to pre-trained embedding manager checkpoints, for custom models |
|
||||||
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
|
| `--gfpgan_dir` | | `src/gfpgan` | Path to where GFPGAN is installed. |
|
||||||
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
|
| `--gfpgan_model_path` | | `experiments/pretrained_models/GFPGANv1.4.pth` | Path to GFPGAN model file, relative to `--gfpgan_dir`. |
|
||||||
| `--device <device>` | `-d<device>` | `torch.cuda.current_device()` | Device to run SD on, e.g. "cuda:0" |
|
|
||||||
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
|
| `--free_gpu_mem` | | `False` | Free GPU memory after sampling, to allow image decoding and saving in low VRAM conditions |
|
||||||
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
|
| `--precision` | | `auto` | Set model precision, default is selected by device. Options: auto, float32, float16, autocast |
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ dependencies:
|
|||||||
# ```
|
# ```
|
||||||
- albumentations==1.2.1
|
- albumentations==1.2.1
|
||||||
- coloredlogs==15.0.1
|
- coloredlogs==15.0.1
|
||||||
|
- diffusers==0.6.0
|
||||||
- einops==0.4.1
|
- einops==0.4.1
|
||||||
- grpcio==1.46.4
|
- grpcio==1.46.4
|
||||||
- humanfriendly==10.0
|
- humanfriendly==10.0
|
||||||
|
@ -26,6 +26,7 @@ dependencies:
|
|||||||
- pyreadline3
|
- pyreadline3
|
||||||
- torch-fidelity==0.3.0
|
- torch-fidelity==0.3.0
|
||||||
- transformers==4.21.3
|
- transformers==4.21.3
|
||||||
|
- diffusers==0.6.0
|
||||||
- torchmetrics==0.7.0
|
- torchmetrics==0.7.0
|
||||||
- flask==2.1.3
|
- flask==2.1.3
|
||||||
- flask_socketio==5.3.0
|
- flask_socketio==5.3.0
|
||||||
|
@ -132,20 +132,21 @@ class Generate:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model = None,
|
model = None,
|
||||||
conf = 'configs/models.yaml',
|
conf = 'configs/models.yaml',
|
||||||
embedding_path = None,
|
embedding_path = None,
|
||||||
sampler_name = 'k_lms',
|
sampler_name = 'k_lms',
|
||||||
ddim_eta = 0.0, # deterministic
|
ddim_eta = 0.0, # deterministic
|
||||||
full_precision = False,
|
full_precision = False,
|
||||||
precision = 'auto',
|
precision = 'auto',
|
||||||
# these are deprecated; if present they override values in the conf file
|
|
||||||
weights = None,
|
|
||||||
config = None,
|
|
||||||
gfpgan=None,
|
gfpgan=None,
|
||||||
codeformer=None,
|
codeformer=None,
|
||||||
esrgan=None,
|
esrgan=None,
|
||||||
free_gpu_mem=False,
|
free_gpu_mem=False,
|
||||||
|
safety_checker:bool=False,
|
||||||
|
# these are deprecated; if present they override values in the conf file
|
||||||
|
weights = None,
|
||||||
|
config = None,
|
||||||
):
|
):
|
||||||
mconfig = OmegaConf.load(conf)
|
mconfig = OmegaConf.load(conf)
|
||||||
self.height = None
|
self.height = None
|
||||||
@ -176,6 +177,7 @@ class Generate:
|
|||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
self.size_matters = True # used to warn once about large image sizes and VRAM
|
self.size_matters = True # used to warn once about large image sizes and VRAM
|
||||||
self.txt2mask = None
|
self.txt2mask = None
|
||||||
|
self.safety_checker = None
|
||||||
|
|
||||||
# Note that in previous versions, there was an option to pass the
|
# Note that in previous versions, there was an option to pass the
|
||||||
# device to Generate(). However the device was then ignored, so
|
# device to Generate(). However the device was then ignored, so
|
||||||
@ -203,6 +205,19 @@ class Generate:
|
|||||||
# gets rid of annoying messages about random seed
|
# gets rid of annoying messages about random seed
|
||||||
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
|
||||||
|
|
||||||
|
# load safety checker if requested
|
||||||
|
if safety_checker:
|
||||||
|
try:
|
||||||
|
print('>> Initializing safety checker')
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
|
||||||
|
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
|
||||||
|
except Exception:
|
||||||
|
print('** An error was encountered while installing the safety checker:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def prompt2png(self, prompt, outdir, **kwargs):
|
def prompt2png(self, prompt, outdir, **kwargs):
|
||||||
"""
|
"""
|
||||||
Takes a prompt and an output directory, writes out the requested number
|
Takes a prompt and an output directory, writes out the requested number
|
||||||
@ -418,6 +433,11 @@ class Generate:
|
|||||||
self.seed, variation_amount, with_variations
|
self.seed, variation_amount, with_variations
|
||||||
)
|
)
|
||||||
|
|
||||||
|
checker = {
|
||||||
|
'checker':self.safety_checker,
|
||||||
|
'extractor':self.safety_feature_extractor
|
||||||
|
} if self.safety_checker else None
|
||||||
|
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
@ -428,10 +448,10 @@ class Generate:
|
|||||||
conditioning=(uc, c),
|
conditioning=(uc, c),
|
||||||
ddim_eta=ddim_eta,
|
ddim_eta=ddim_eta,
|
||||||
image_callback=image_callback, # called after the final image is generated
|
image_callback=image_callback, # called after the final image is generated
|
||||||
step_callback=step_callback, # called after each intermediate image is generated
|
step_callback=step_callback, # called after each intermediate image is generated
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
init_img=init_img, # embiggen needs to manipulate from the unmodified init_img
|
||||||
init_image=init_image, # notice that init_image is different from init_img
|
init_image=init_image, # notice that init_image is different from init_img
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
@ -440,7 +460,8 @@ class Generate:
|
|||||||
embiggen=embiggen,
|
embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
mask_blur_radius=mask_blur_radius
|
mask_blur_radius=mask_blur_radius,
|
||||||
|
safety_checker=checker
|
||||||
)
|
)
|
||||||
|
|
||||||
if init_color:
|
if init_color:
|
||||||
|
@ -418,6 +418,11 @@ class Args(object):
|
|||||||
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
help=f'Set model precision. Defaults to auto selected based on device. Options: {", ".join(PRECISION_CHOICES)}',
|
||||||
default='auto',
|
default='auto',
|
||||||
)
|
)
|
||||||
|
model_group.add_argument(
|
||||||
|
'--safety_checker',
|
||||||
|
action='store_true',
|
||||||
|
help='Check for and blur potentially NSFW images',
|
||||||
|
)
|
||||||
file_group.add_argument(
|
file_group.add_argument(
|
||||||
'--from_file',
|
'--from_file',
|
||||||
dest='infile',
|
dest='infile',
|
||||||
|
@ -7,25 +7,27 @@ import numpy as np
|
|||||||
import random
|
import random
|
||||||
import os
|
import os
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
from PIL import Image
|
from PIL import Image, ImageFilter
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from ldm.invoke.devices import choose_autocast
|
from ldm.invoke.devices import choose_autocast
|
||||||
from ldm.util import rand_perlin_2d
|
from ldm.util import rand_perlin_2d
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
CAUTION_IMG = 'assets/caution.png'
|
||||||
|
|
||||||
class Generator():
|
class Generator():
|
||||||
def __init__(self, model, precision):
|
def __init__(self, model, precision):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
self.seed = None
|
self.seed = None
|
||||||
self.latent_channels = model.channels
|
self.latent_channels = model.channels
|
||||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
self.perlin = 0.0
|
self.safety_checker = None
|
||||||
self.threshold = 0
|
self.perlin = 0.0
|
||||||
self.variation_amount = 0
|
self.threshold = 0
|
||||||
self.with_variations = []
|
self.variation_amount = 0
|
||||||
|
self.with_variations = []
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self,prompt,**kwargs):
|
def get_make_image(self,prompt,**kwargs):
|
||||||
@ -42,8 +44,10 @@ class Generator():
|
|||||||
|
|
||||||
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
def generate(self,prompt,init_image,width,height,iterations=1,seed=None,
|
||||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||||
|
safety_checker:dict=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
scope = choose_autocast(self.precision)
|
scope = choose_autocast(self.precision)
|
||||||
|
self.safety_checker = safety_checker
|
||||||
make_image = self.get_make_image(
|
make_image = self.get_make_image(
|
||||||
prompt,
|
prompt,
|
||||||
init_image = init_image,
|
init_image = init_image,
|
||||||
@ -77,10 +81,17 @@ class Generator():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
|
|
||||||
|
if self.safety_checker is not None:
|
||||||
|
image = self.safety_check(image)
|
||||||
|
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
image_callback(image, seed, first_seed=first_seed)
|
image_callback(image, seed, first_seed=first_seed)
|
||||||
|
|
||||||
seed = self.new_seed()
|
seed = self.new_seed()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def sample_to_image(self,samples):
|
def sample_to_image(self,samples):
|
||||||
@ -169,6 +180,39 @@ class Generator():
|
|||||||
|
|
||||||
return v2
|
return v2
|
||||||
|
|
||||||
|
def safety_check(self,image:Image.Image):
|
||||||
|
'''
|
||||||
|
If the CompViz safety checker flags an NSFW image, we
|
||||||
|
blur it out.
|
||||||
|
'''
|
||||||
|
import diffusers
|
||||||
|
|
||||||
|
checker = self.safety_checker['checker']
|
||||||
|
extractor = self.safety_checker['extractor']
|
||||||
|
features = extractor([image], return_tensors="pt")
|
||||||
|
|
||||||
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||||
|
if has_nsfw_concept[0]:
|
||||||
|
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||||
|
return self.blur(image)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def blur(self,input):
|
||||||
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
try:
|
||||||
|
caution = Image.open(CAUTION_IMG)
|
||||||
|
caution = caution.resize((caution.width // 2, caution.height //2))
|
||||||
|
blurry.paste(caution,(0,0),caution)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return blurry
|
||||||
|
|
||||||
# this is a handy routine for debugging use. Given a generated sample,
|
# this is a handy routine for debugging use. Given a generated sample,
|
||||||
# convert it into a PNG image and store it at the indicated path
|
# convert it into a PNG image and store it at the indicated path
|
||||||
def save_sample(self, sample, filepath):
|
def save_sample(self, sample, filepath):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
albumentations==0.4.3
|
albumentations==0.4.3
|
||||||
einops==0.3.0
|
einops==0.3.0
|
||||||
|
diffusers==0.6.0
|
||||||
huggingface-hub==0.8.1
|
huggingface-hub==0.8.1
|
||||||
imageio==2.9.0
|
imageio==2.9.0
|
||||||
imageio-ffmpeg==0.4.2
|
imageio-ffmpeg==0.4.2
|
||||||
|
@ -32,6 +32,7 @@ send2trash
|
|||||||
dependency_injector==4.40.0
|
dependency_injector==4.40.0
|
||||||
eventlet
|
eventlet
|
||||||
realesrgan
|
realesrgan
|
||||||
|
diffusers
|
||||||
git+https://github.com/openai/CLIP.git@main#egg=clip
|
git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k-diffusion
|
||||||
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
git+https://github.com/TencentARC/GFPGAN.git#egg=gfpgan
|
||||||
|
@ -69,16 +69,17 @@ def main():
|
|||||||
# creating a Generate object:
|
# creating a Generate object:
|
||||||
try:
|
try:
|
||||||
gen = Generate(
|
gen = Generate(
|
||||||
conf = opt.conf,
|
conf = opt.conf,
|
||||||
model = opt.model,
|
model = opt.model,
|
||||||
sampler_name = opt.sampler_name,
|
sampler_name = opt.sampler_name,
|
||||||
embedding_path = opt.embedding_path,
|
embedding_path = opt.embedding_path,
|
||||||
full_precision = opt.full_precision,
|
full_precision = opt.full_precision,
|
||||||
precision = opt.precision,
|
precision = opt.precision,
|
||||||
gfpgan=gfpgan,
|
gfpgan=gfpgan,
|
||||||
codeformer=codeformer,
|
codeformer=codeformer,
|
||||||
esrgan=esrgan,
|
esrgan=esrgan,
|
||||||
free_gpu_mem=opt.free_gpu_mem,
|
free_gpu_mem=opt.free_gpu_mem,
|
||||||
|
safety_checker=opt.safety_checker,
|
||||||
)
|
)
|
||||||
except (FileNotFoundError, IOError, KeyError) as e:
|
except (FileNotFoundError, IOError, KeyError) as e:
|
||||||
print(f'{e}. Aborting.')
|
print(f'{e}. Aborting.')
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
# two machines must share a common .cache directory.
|
# two machines must share a common .cache directory.
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
import clip
|
import clip
|
||||||
from transformers import BertTokenizerFast
|
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||||
import sys
|
import sys
|
||||||
import transformers
|
import transformers
|
||||||
import os
|
import os
|
||||||
@ -17,41 +17,39 @@ import traceback
|
|||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
# this will preload the Bert tokenizer fles
|
# this will preload the Bert tokenizer fles
|
||||||
print('Loading bert tokenizer (ignore deprecation errors)...', end='')
|
def download_bert():
|
||||||
with warnings.catch_warnings():
|
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
with warnings.catch_warnings():
|
||||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
print('...success')
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||||
sys.stdout.flush()
|
print('...success')
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
# this will download requirements for Kornia
|
# this will download requirements for Kornia
|
||||||
print('Loading Kornia requirements...', end='')
|
def download_kornia():
|
||||||
with warnings.catch_warnings():
|
print('Installing Kornia requirements...', end='')
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
with warnings.catch_warnings():
|
||||||
import kornia
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
print('...success')
|
import kornia
|
||||||
|
print('...success')
|
||||||
|
|
||||||
version = 'openai/clip-vit-large-patch14'
|
#---------------------------------------------
|
||||||
sys.stdout.flush()
|
def download_clip():
|
||||||
print('Loading CLIP model...',end='')
|
version = 'openai/clip-vit-large-patch14'
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
sys.stdout.flush()
|
||||||
transformer = CLIPTextModel.from_pretrained(version)
|
print('Loading CLIP model...',end='')
|
||||||
print('...success')
|
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
|
transformer = CLIPTextModel.from_pretrained(version)
|
||||||
|
print('...success')
|
||||||
|
|
||||||
# In the event that the user has installed GFPGAN and also elected to use
|
#---------------------------------------------
|
||||||
# RealESRGAN, this will attempt to download the model needed by RealESRGANer
|
def download_gfpgan():
|
||||||
gfpgan = False
|
print('Installing models from RealESRGAN and facexlib...',end='')
|
||||||
try:
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
|
|
||||||
gfpgan = True
|
|
||||||
except ModuleNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if gfpgan:
|
|
||||||
print('Loading models from RealESRGAN and facexlib...',end='')
|
|
||||||
try:
|
try:
|
||||||
|
from realesrgan import RealESRGANer
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
@ -94,44 +92,72 @@ if gfpgan:
|
|||||||
print('Error loading GFPGAN:')
|
print('Error loading GFPGAN:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
print('preloading CodeFormer model file...',end='')
|
#---------------------------------------------
|
||||||
try:
|
def download_codeformer():
|
||||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
print('Installing CodeFormer model file...',end='')
|
||||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
try:
|
||||||
if not os.path.exists(model_dest):
|
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||||
print('Downloading codeformer model file...')
|
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||||
|
if not os.path.exists(model_dest):
|
||||||
|
print('Downloading codeformer model file...')
|
||||||
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
|
urllib.request.urlretrieve(model_url,model_dest)
|
||||||
|
except Exception:
|
||||||
|
print('Error loading CodeFormer:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print('...success')
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
|
def download_clipseg():
|
||||||
|
print('Installing clipseg model for text-based masking...',end='')
|
||||||
|
try:
|
||||||
|
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||||
|
model_dest = 'src/clipseg/clipseg_weights.zip'
|
||||||
|
weights_dir = 'src/clipseg/weights'
|
||||||
|
if not os.path.exists(weights_dir):
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
urllib.request.urlretrieve(model_url,model_dest)
|
urllib.request.urlretrieve(model_url,model_dest)
|
||||||
except Exception:
|
with zipfile.ZipFile(model_dest,'r') as zip:
|
||||||
print('Error loading CodeFormer:')
|
zip.extractall('src/clipseg')
|
||||||
print(traceback.format_exc())
|
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
||||||
print('...success')
|
os.remove(model_dest)
|
||||||
|
from clipseg_models.clipseg import CLIPDensePredT
|
||||||
|
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
||||||
|
model.eval()
|
||||||
|
model.load_state_dict(
|
||||||
|
torch.load(
|
||||||
|
'src/clipseg/weights/rd64-uni-refined.pth',
|
||||||
|
map_location=torch.device('cpu')
|
||||||
|
),
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
print('Error installing clipseg model:')
|
||||||
|
print(traceback.format_exc())
|
||||||
|
print('...success')
|
||||||
|
|
||||||
print('Loading clipseg model for text-based masking...',end='')
|
#-------------------------------------
|
||||||
try:
|
def download_safety_checker():
|
||||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
print('Installing safety model for NSFW content detection...',end='')
|
||||||
model_dest = 'src/clipseg/clipseg_weights.zip'
|
try:
|
||||||
weights_dir = 'src/clipseg/weights'
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
if not os.path.exists(weights_dir):
|
except ModuleNotFoundError:
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
print('Error installing safety checker model:')
|
||||||
urllib.request.urlretrieve(model_url,model_dest)
|
print(traceback.format_exc())
|
||||||
with zipfile.ZipFile(model_dest,'r') as zip:
|
return
|
||||||
zip.extractall('src/clipseg')
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
|
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||||
os.remove(model_dest)
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||||
from clipseg_models.clipseg import CLIPDensePredT
|
print('...success')
|
||||||
model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
|
|
||||||
model.eval()
|
#-------------------------------------
|
||||||
model.load_state_dict(
|
if __name__ == '__main__':
|
||||||
torch.load(
|
download_bert()
|
||||||
'src/clipseg/weights/rd64-uni-refined.pth',
|
download_kornia()
|
||||||
map_location=torch.device('cpu')
|
download_clip()
|
||||||
),
|
download_gfpgan()
|
||||||
strict=False,
|
download_codeformer()
|
||||||
)
|
download_clipseg()
|
||||||
except Exception:
|
download_safety_checker()
|
||||||
print('Error installing clipseg model:')
|
|
||||||
print(traceback.format_exc())
|
|
||||||
print('...success')
|
|
||||||
|
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user