mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
restore NSFW checker
This commit is contained in:
parent
580f9ecded
commit
3aa1ee1218
@ -11,5 +11,6 @@ from .generator import (
|
|||||||
Inpaint
|
Inpaint
|
||||||
)
|
)
|
||||||
from .model_management import ModelManager
|
from .model_management import ModelManager
|
||||||
|
from .safety_checker import SafetyChecker
|
||||||
from .args import Args
|
from .args import Args
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
|
@ -25,18 +25,19 @@ from accelerate.utils import set_seed
|
|||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.utils.import_utils import is_xformers_available
|
from diffusers.utils.import_utils import is_xformers_available
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from .args import metadata_from_png
|
from .args import metadata_from_png
|
||||||
from .generator import infill_methods
|
from .generator import infill_methods
|
||||||
from .globals import Globals, global_cache_dir
|
from .globals import Globals, global_cache_dir
|
||||||
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding
|
||||||
from .model_management import ModelManager
|
from .model_management import ModelManager
|
||||||
|
from .safety_checker import SafetyChecker
|
||||||
from .prompting import get_uc_and_c_and_ec
|
from .prompting import get_uc_and_c_and_ec
|
||||||
from .prompting.conditioning import log_tokenization
|
from .prompting.conditioning import log_tokenization
|
||||||
from .stable_diffusion import HuggingFaceConceptsLibrary
|
from .stable_diffusion import HuggingFaceConceptsLibrary
|
||||||
from .util import choose_precision, choose_torch_device
|
from .util import choose_precision, choose_torch_device
|
||||||
|
|
||||||
|
|
||||||
def fix_func(orig):
|
def fix_func(orig):
|
||||||
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||||
|
|
||||||
@ -245,31 +246,8 @@ class Generate:
|
|||||||
|
|
||||||
# load safety checker if requested
|
# load safety checker if requested
|
||||||
if safety_checker:
|
if safety_checker:
|
||||||
try:
|
print(">> Initializing NSFW checker")
|
||||||
print(">> Initializing NSFW checker")
|
self.safety_checker = SafetyChecker(self.device)
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
)
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
|
||||||
safety_model_path = global_cache_dir("hub")
|
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
|
||||||
safety_model_id,
|
|
||||||
local_files_only=True,
|
|
||||||
cache_dir=safety_model_path,
|
|
||||||
)
|
|
||||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
||||||
safety_model_id,
|
|
||||||
local_files_only=True,
|
|
||||||
cache_dir=safety_model_path,
|
|
||||||
)
|
|
||||||
self.safety_checker.to(self.device)
|
|
||||||
except Exception:
|
|
||||||
print(
|
|
||||||
"** An error was encountered while installing the safety checker:"
|
|
||||||
)
|
|
||||||
print(traceback.format_exc())
|
|
||||||
else:
|
else:
|
||||||
print(">> NSFW checker is disabled")
|
print(">> NSFW checker is disabled")
|
||||||
|
|
||||||
@ -524,15 +502,6 @@ class Generate:
|
|||||||
generator.set_variation(self.seed, variation_amount, with_variations)
|
generator.set_variation(self.seed, variation_amount, with_variations)
|
||||||
generator.use_mps_noise = use_mps_noise
|
generator.use_mps_noise = use_mps_noise
|
||||||
|
|
||||||
checker = (
|
|
||||||
{
|
|
||||||
"checker": self.safety_checker,
|
|
||||||
"extractor": self.safety_feature_extractor,
|
|
||||||
}
|
|
||||||
if self.safety_checker
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
results = generator.generate(
|
results = generator.generate(
|
||||||
prompt,
|
prompt,
|
||||||
iterations=iterations,
|
iterations=iterations,
|
||||||
@ -559,7 +528,7 @@ class Generate:
|
|||||||
embiggen_strength=embiggen_strength,
|
embiggen_strength=embiggen_strength,
|
||||||
inpaint_replace=inpaint_replace,
|
inpaint_replace=inpaint_replace,
|
||||||
mask_blur_radius=mask_blur_radius,
|
mask_blur_radius=mask_blur_radius,
|
||||||
safety_checker=checker,
|
safety_checker=self.safety_checker,
|
||||||
seam_size=seam_size,
|
seam_size=seam_size,
|
||||||
seam_blur=seam_blur,
|
seam_blur=seam_blur,
|
||||||
seam_strength=seam_strength,
|
seam_strength=seam_strength,
|
||||||
|
@ -15,14 +15,18 @@ from transformers import AutoFeatureExtractor
|
|||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
from .globals import global_cache_dir
|
from .globals import global_cache_dir
|
||||||
|
from .util import CPU_DEVICE
|
||||||
|
|
||||||
class SafetyChecker(object):
|
class SafetyChecker(object):
|
||||||
CAUTION_IMG = "caution.png"
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
def __init__(self, device: torch.device):
|
def __init__(self, device: torch.device):
|
||||||
|
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||||
|
caution = Image.open(path)
|
||||||
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(">> Initializing NSFW checker")
|
|
||||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
safety_model_path = global_cache_dir("hub")
|
safety_model_path = global_cache_dir("hub")
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
@ -35,15 +39,11 @@ class SafetyChecker(object):
|
|||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
cache_dir=safety_model_path,
|
cache_dir=safety_model_path,
|
||||||
)
|
)
|
||||||
self.safety_checker.to(device)
|
|
||||||
self.safety_feature_extractor.to(device)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
print(
|
print(
|
||||||
"** An error was encountered while installing the safety checker:"
|
"** An error was encountered while installing the safety checker:"
|
||||||
)
|
)
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
else:
|
|
||||||
print(">> NSFW checker is disabled")
|
|
||||||
|
|
||||||
def check(self, image: Image.Image):
|
def check(self, image: Image.Image):
|
||||||
"""
|
"""
|
||||||
@ -51,7 +51,10 @@ class SafetyChecker(object):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self.safety_checker.to(self.device)
|
||||||
features = self.safety_feature_extractor([image], return_tensors="pt")
|
features = self.safety_feature_extractor([image], return_tensors="pt")
|
||||||
|
features.to(self.device)
|
||||||
|
|
||||||
# unfortunately checker requires the numpy version, so we have to convert back
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
x_image = np.array(image).astype(np.float32) / 255.0
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
@ -60,6 +63,7 @@ class SafetyChecker(object):
|
|||||||
checked_image, has_nsfw_concept = self.safety_checker(
|
checked_image, has_nsfw_concept = self.safety_checker(
|
||||||
images=x_image, clip_input=features.pixel_values
|
images=x_image, clip_input=features.pixel_values
|
||||||
)
|
)
|
||||||
|
self.safety_checker.to(CPU_DEVICE) # offload
|
||||||
if has_nsfw_concept[0]:
|
if has_nsfw_concept[0]:
|
||||||
print(
|
print(
|
||||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||||
@ -71,19 +75,8 @@ class SafetyChecker(object):
|
|||||||
def blur(self, input):
|
def blur(self, input):
|
||||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
try:
|
try:
|
||||||
caution = self.get_caution_img()
|
if caution := self.caution_img:
|
||||||
if caution:
|
|
||||||
blurry.paste(caution, (0, 0), caution)
|
blurry.paste(caution, (0, 0), caution)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
return blurry
|
return blurry
|
||||||
|
|
||||||
def get_caution_img(self):
|
|
||||||
path = None
|
|
||||||
if self.caution_img:
|
|
||||||
return self.caution_img
|
|
||||||
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
|
||||||
caution = Image.open(path)
|
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
return self.caution_img
|
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def simple_graph():
|
|||||||
def mock_services():
|
def mock_services():
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
generate = None,
|
model_manager = None,
|
||||||
events = None,
|
events = None,
|
||||||
images = None,
|
images = None,
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
|
@ -21,7 +21,7 @@ def simple_graph():
|
|||||||
def mock_services() -> InvocationServices:
|
def mock_services() -> InvocationServices:
|
||||||
# NOTE: none of these are actually called by the test invocations
|
# NOTE: none of these are actually called by the test invocations
|
||||||
return InvocationServices(
|
return InvocationServices(
|
||||||
generate = None, # type: ignore
|
model_manager = None, # type: ignore
|
||||||
events = TestEventService(),
|
events = TestEventService(),
|
||||||
images = None, # type: ignore
|
images = None, # type: ignore
|
||||||
queue = MemoryInvocationQueue(),
|
queue = MemoryInvocationQueue(),
|
||||||
|
Loading…
Reference in New Issue
Block a user