diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index a816486631..06066dd6b1 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -11,5 +11,6 @@ from .generator import ( Inpaint ) from .model_management import ModelManager +from .safety_checker import SafetyChecker from .args import Args from .globals import Globals diff --git a/invokeai/backend/generate.py b/invokeai/backend/generate.py index 22e4ff177d..7db0c4a2ef 100644 --- a/invokeai/backend/generate.py +++ b/invokeai/backend/generate.py @@ -25,18 +25,19 @@ from accelerate.utils import set_seed from diffusers.pipeline_utils import DiffusionPipeline from diffusers.utils.import_utils import is_xformers_available from omegaconf import OmegaConf +from pathlib import Path from .args import metadata_from_png from .generator import infill_methods from .globals import Globals, global_cache_dir from .image_util import InitImageResizer, PngWriter, Txt2Mask, configure_model_padding from .model_management import ModelManager +from .safety_checker import SafetyChecker from .prompting import get_uc_and_c_and_ec from .prompting.conditioning import log_tokenization from .stable_diffusion import HuggingFaceConceptsLibrary from .util import choose_precision, choose_torch_device - def fix_func(orig): if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): @@ -245,31 +246,8 @@ class Generate: # load safety checker if requested if safety_checker: - try: - print(">> Initializing NSFW checker") - from diffusers.pipelines.stable_diffusion.safety_checker import ( - StableDiffusionSafetyChecker, - ) - from transformers import AutoFeatureExtractor - - safety_model_id = "CompVis/stable-diffusion-safety-checker" - safety_model_path = global_cache_dir("hub") - self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( - safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, - ) - self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( - safety_model_id, - local_files_only=True, - cache_dir=safety_model_path, - ) - self.safety_checker.to(self.device) - except Exception: - print( - "** An error was encountered while installing the safety checker:" - ) - print(traceback.format_exc()) + print(">> Initializing NSFW checker") + self.safety_checker = SafetyChecker(self.device) else: print(">> NSFW checker is disabled") @@ -524,15 +502,6 @@ class Generate: generator.set_variation(self.seed, variation_amount, with_variations) generator.use_mps_noise = use_mps_noise - checker = ( - { - "checker": self.safety_checker, - "extractor": self.safety_feature_extractor, - } - if self.safety_checker - else None - ) - results = generator.generate( prompt, iterations=iterations, @@ -559,7 +528,7 @@ class Generate: embiggen_strength=embiggen_strength, inpaint_replace=inpaint_replace, mask_blur_radius=mask_blur_radius, - safety_checker=checker, + safety_checker=self.safety_checker, seam_size=seam_size, seam_blur=seam_blur, seam_strength=seam_strength, diff --git a/invokeai/backend/safety_checker.py b/invokeai/backend/safety_checker.py index 86cf31cc13..2e6c4fd479 100644 --- a/invokeai/backend/safety_checker.py +++ b/invokeai/backend/safety_checker.py @@ -15,14 +15,18 @@ from transformers import AutoFeatureExtractor import invokeai.assets.web as web_assets from .globals import global_cache_dir +from .util import CPU_DEVICE class SafetyChecker(object): CAUTION_IMG = "caution.png" def __init__(self, device: torch.device): + path = Path(web_assets.__path__[0]) / self.CAUTION_IMG + caution = Image.open(path) + self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) self.device = device + try: - print(">> Initializing NSFW checker") safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_path = global_cache_dir("hub") self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( @@ -35,15 +39,11 @@ class SafetyChecker(object): local_files_only=True, cache_dir=safety_model_path, ) - self.safety_checker.to(device) - self.safety_feature_extractor.to(device) except Exception: print( "** An error was encountered while installing the safety checker:" ) print(traceback.format_exc()) - else: - print(">> NSFW checker is disabled") def check(self, image: Image.Image): """ @@ -51,7 +51,10 @@ class SafetyChecker(object): """ + self.safety_checker.to(self.device) features = self.safety_feature_extractor([image], return_tensors="pt") + features.to(self.device) + # unfortunately checker requires the numpy version, so we have to convert back x_image = np.array(image).astype(np.float32) / 255.0 x_image = x_image[None].transpose(0, 3, 1, 2) @@ -60,6 +63,7 @@ class SafetyChecker(object): checked_image, has_nsfw_concept = self.safety_checker( images=x_image, clip_input=features.pixel_values ) + self.safety_checker.to(CPU_DEVICE) # offload if has_nsfw_concept[0]: print( "** An image with potential non-safe content has been detected. A blurred image will be returned. **" @@ -71,19 +75,8 @@ class SafetyChecker(object): def blur(self, input): blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) try: - caution = self.get_caution_img() - if caution: + if caution := self.caution_img: blurry.paste(caution, (0, 0), caution) except FileNotFoundError: pass return blurry - - def get_caution_img(self): - path = None - if self.caution_img: - return self.caution_img - path = Path(web_assets.__path__[0]) / self.CAUTION_IMG - caution = Image.open(path) - self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) - return self.caution_img - diff --git a/tests/nodes/test_graph_execution_state.py b/tests/nodes/test_graph_execution_state.py index 4c22507098..8cd6baf010 100644 --- a/tests/nodes/test_graph_execution_state.py +++ b/tests/nodes/test_graph_execution_state.py @@ -21,7 +21,7 @@ def simple_graph(): def mock_services(): # NOTE: none of these are actually called by the test invocations return InvocationServices( - generate = None, + model_manager = None, events = None, images = None, queue = MemoryInvocationQueue(), diff --git a/tests/nodes/test_invoker.py b/tests/nodes/test_invoker.py index 6a7867bffe..a2cc92ce7a 100644 --- a/tests/nodes/test_invoker.py +++ b/tests/nodes/test_invoker.py @@ -21,7 +21,7 @@ def simple_graph(): def mock_services() -> InvocationServices: # NOTE: none of these are actually called by the test invocations return InvocationServices( - generate = None, # type: ignore + model_manager = None, # type: ignore events = TestEventService(), images = None, # type: ignore queue = MemoryInvocationQueue(),