restore NSFW checker

This commit is contained in:
Lincoln Stein 2023-03-11 16:16:44 -05:00
parent 580f9ecded
commit 3aa1ee1218
5 changed files with 18 additions and 55 deletions

View File

@ -11,5 +11,6 @@ from .generator import (
Inpaint Inpaint
) )
from .model_management import ModelManager from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .args import Args from .args import Args
from .globals import Globals from .globals import Globals

View File

@ -25,18 +25,19 @@ from accelerate.utils import set_seed
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path
from .args import metadata_from_png from .args import metadata_from_png
from .generator import infill_methods from .generator import infill_methods
from .globals import Globals, global_cache_dir from .globals import Globals, global_cache_dir
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
from .model_management import ModelManager from .model_management import ModelManager
from .safety_checker import SafetyChecker
from .prompting import get_uc_and_c_and_ec from .prompting import get_uc_and_c_and_ec
from .prompting.conditioning import log_tokenization from .prompting.conditioning import log_tokenization
from .stable_diffusion import HuggingFaceConceptsLibrary from .stable_diffusion import HuggingFaceConceptsLibrary
from .util import choose_precision, choose_torch_device from .util import choose_precision, choose_torch_device
def fix_func(orig): def fix_func(orig):
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
@ -245,31 +246,8 @@ class Generate:
# load safety checker if requested # load safety checker if requested
if safety_checker: if safety_checker:
try:
print(">> Initializing NSFW checker") print(">> Initializing NSFW checker")
from diffusers.pipelines.stable_diffusion.safety_checker import ( self.safety_checker = SafetyChecker(self.device)
StableDiffusionSafetyChecker,
)
from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_checker.to(self.device)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
else: else:
print(">> NSFW checker is disabled") print(">> NSFW checker is disabled")
@ -524,15 +502,6 @@ class Generate:
generator.set_variation(self.seed, variation_amount, with_variations) generator.set_variation(self.seed, variation_amount, with_variations)
generator.use_mps_noise = use_mps_noise generator.use_mps_noise = use_mps_noise
checker = (
{
"checker": self.safety_checker,
"extractor": self.safety_feature_extractor,
}
if self.safety_checker
else None
)
results = generator.generate( results = generator.generate(
prompt, prompt,
iterations=iterations, iterations=iterations,
@ -559,7 +528,7 @@ class Generate:
embiggen_strength=embiggen_strength, embiggen_strength=embiggen_strength,
inpaint_replace=inpaint_replace, inpaint_replace=inpaint_replace,
mask_blur_radius=mask_blur_radius, mask_blur_radius=mask_blur_radius,
safety_checker=checker, safety_checker=self.safety_checker,
seam_size=seam_size, seam_size=seam_size,
seam_blur=seam_blur, seam_blur=seam_blur,
seam_strength=seam_strength, seam_strength=seam_strength,

View File

@ -15,14 +15,18 @@ from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets import invokeai.assets.web as web_assets
from .globals import global_cache_dir from .globals import global_cache_dir
from .util import CPU_DEVICE
class SafetyChecker(object): class SafetyChecker(object):
CAUTION_IMG = "caution.png" CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device): def __init__(self, device: torch.device):
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device self.device = device
try: try:
print(">> Initializing NSFW checker")
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub") safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
@ -35,15 +39,11 @@ class SafetyChecker(object):
local_files_only=True, local_files_only=True,
cache_dir=safety_model_path, cache_dir=safety_model_path,
) )
self.safety_checker.to(device)
self.safety_feature_extractor.to(device)
except Exception: except Exception:
print( print(
"** An error was encountered while installing the safety checker:" "** An error was encountered while installing the safety checker:"
) )
print(traceback.format_exc()) print(traceback.format_exc())
else:
print(">> NSFW checker is disabled")
def check(self, image: Image.Image): def check(self, image: Image.Image):
""" """
@ -51,7 +51,10 @@ class SafetyChecker(object):
""" """
self.safety_checker.to(self.device)
features = self.safety_feature_extractor([image], return_tensors="pt") features = self.safety_feature_extractor([image], return_tensors="pt")
features.to(self.device)
# unfortunately checker requires the numpy version, so we have to convert back # unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0 x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2) x_image = x_image[None].transpose(0, 3, 1, 2)
@ -60,6 +63,7 @@ class SafetyChecker(object):
checked_image, has_nsfw_concept = self.safety_checker( checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values images=x_image, clip_input=features.pixel_values
) )
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]: if has_nsfw_concept[0]:
print( print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **" "** An image with potential non-safe content has been detected. A blurred image will be returned. **"
@ -71,19 +75,8 @@ class SafetyChecker(object):
def blur(self, input): def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try: try:
caution = self.get_caution_img() if caution := self.caution_img:
if caution:
blurry.paste(caution, (0, 0), caution) blurry.paste(caution, (0, 0), caution)
except FileNotFoundError: except FileNotFoundError:
pass pass
return blurry return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img

View File

@ -21,7 +21,7 @@ def simple_graph():
def mock_services(): def mock_services():
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices( return InvocationServices(
generate = None, model_manager = None,
events = None, events = None,
images = None, images = None,
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),

View File

@ -21,7 +21,7 @@ def simple_graph():
def mock_services() -> InvocationServices: def mock_services() -> InvocationServices:
# NOTE: none of these are actually called by the test invocations # NOTE: none of these are actually called by the test invocations
return InvocationServices( return InvocationServices(
generate = None, # type: ignore model_manager = None, # type: ignore
events = TestEventService(), events = TestEventService(),
images = None, # type: ignore images = None, # type: ignore
queue = MemoryInvocationQueue(), queue = MemoryInvocationQueue(),