mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restore NSFW checker
This commit is contained in:
parent
580f9ecded
commit
3aa1ee1218
@ -11,5 +11,6 @@ from .generator import (
|
||||
Inpaint
|
||||
)
|
||||
from .model_management import ModelManager
|
||||
from .safety_checker import SafetyChecker
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
@ -25,18 +25,19 @@ from accelerate.utils import set_seed
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
|
||||
from .args import metadata_from_png
|
||||
from .generator import infill_methods
|
||||
from .globals import Globals, global_cache_dir
|
||||
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
||||
from .model_management import ModelManager
|
||||
from .safety_checker import SafetyChecker
|
||||
from .prompting import get_uc_and_c_and_ec
|
||||
from .prompting.conditioning import log_tokenization
|
||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||
from .util import choose_precision, choose_torch_device
|
||||
|
||||
|
||||
def fix_func(orig):
|
||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
|
||||
@ -245,31 +246,8 @@ class Generate:
|
||||
|
||||
# load safety checker if requested
|
||||
if safety_checker:
|
||||
try:
|
||||
print(">> Initializing NSFW checker")
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_checker.to(self.device)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
print(">> Initializing NSFW checker")
|
||||
self.safety_checker = SafetyChecker(self.device)
|
||||
else:
|
||||
print(">> NSFW checker is disabled")
|
||||
|
||||
@ -524,15 +502,6 @@ class Generate:
|
||||
generator.set_variation(self.seed, variation_amount, with_variations)
|
||||
generator.use_mps_noise = use_mps_noise
|
||||
|
||||
checker = (
|
||||
{
|
||||
"checker": self.safety_checker,
|
||||
"extractor": self.safety_feature_extractor,
|
||||
}
|
||||
if self.safety_checker
|
||||
else None
|
||||
)
|
||||
|
||||
results = generator.generate(
|
||||
prompt,
|
||||
iterations=iterations,
|
||||
@ -559,7 +528,7 @@ class Generate:
|
||||
embiggen_strength=embiggen_strength,
|
||||
inpaint_replace=inpaint_replace,
|
||||
mask_blur_radius=mask_blur_radius,
|
||||
safety_checker=checker,
|
||||
safety_checker=self.safety_checker,
|
||||
seam_size=seam_size,
|
||||
seam_blur=seam_blur,
|
||||
seam_strength=seam_strength,
|
||||
|
@ -15,14 +15,18 @@ from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
from .globals import global_cache_dir
|
||||
from .util import CPU_DEVICE
|
||||
|
||||
class SafetyChecker(object):
|
||||
CAUTION_IMG = "caution.png"
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
self.device = device
|
||||
|
||||
try:
|
||||
print(">> Initializing NSFW checker")
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
@ -35,15 +39,11 @@ class SafetyChecker(object):
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_checker.to(device)
|
||||
self.safety_feature_extractor.to(device)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
print(">> NSFW checker is disabled")
|
||||
|
||||
def check(self, image: Image.Image):
|
||||
"""
|
||||
@ -51,7 +51,10 @@ class SafetyChecker(object):
|
||||
|
||||
"""
|
||||
|
||||
self.safety_checker.to(self.device)
|
||||
features = self.safety_feature_extractor([image], return_tensors="pt")
|
||||
features.to(self.device)
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
@ -60,6 +63,7 @@ class SafetyChecker(object):
|
||||
checked_image, has_nsfw_concept = self.safety_checker(
|
||||
images=x_image, clip_input=features.pixel_values
|
||||
)
|
||||
self.safety_checker.to(CPU_DEVICE) # offload
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
@ -71,19 +75,8 @@ class SafetyChecker(object):
|
||||
def blur(self, input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = self.get_caution_img()
|
||||
if caution:
|
||||
if caution := self.caution_img:
|
||||
blurry.paste(caution, (0, 0), caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
def get_caution_img(self):
|
||||
path = None
|
||||
if self.caution_img:
|
||||
return self.caution_img
|
||||
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
return self.caution_img
|
||||
|
||||
|
@ -21,7 +21,7 @@ def simple_graph():
|
||||
def mock_services():
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(
|
||||
generate = None,
|
||||
model_manager = None,
|
||||
events = None,
|
||||
images = None,
|
||||
queue = MemoryInvocationQueue(),
|
||||
|
@ -21,7 +21,7 @@ def simple_graph():
|
||||
def mock_services() -> InvocationServices:
|
||||
# NOTE: none of these are actually called by the test invocations
|
||||
return InvocationServices(
|
||||
generate = None, # type: ignore
|
||||
model_manager = None, # type: ignore
|
||||
events = TestEventService(),
|
||||
images = None, # type: ignore
|
||||
queue = MemoryInvocationQueue(),
|
||||
|
Loading…
Reference in New Issue
Block a user