diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index 2137aa9be7..4cbdc81b28 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -12,7 +12,6 @@ from pydantic import BaseModel, Field from invokeai.app.invocations.upscale import ESRGAN_MODELS from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus -from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.util.logging import logging @@ -114,9 +113,7 @@ async def get_config() -> AppConfig: if SafetyChecker.safety_checker_available(): nsfw_methods.append("nsfw_checker") - watermarking_methods = [] - if InvisibleWatermark.invisible_watermark_available(): - watermarking_methods.append("invisible_watermark") + watermarking_methods = ["invisible_watermark"] return AppConfig( infill_methods=infill_methods, diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index b92a73c24f..679aa2fa10 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -4,15 +4,18 @@ wraps the safety_checker model. It respects the global "nsfw_checker" configuration variable, that allows the checker to be supressed. """ + import numpy as np +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from PIL import Image +from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger -from invokeai.app.services.config import InvokeAIAppConfig +from invokeai.app.services.config.config_default import get_config from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.silence_warnings import SilenceWarnings -config = InvokeAIAppConfig.get_config() +config = get_config() CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" @@ -31,18 +34,12 @@ class SafetyChecker: if cls.tried_load: return - if config.nsfw_checker: - try: - from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker - from transformers import AutoFeatureExtractor - - cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH) - cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH) - logger.info("NSFW checker initialized") - except Exception as e: - logger.warning(f"Could not load NSFW checker: {str(e)}") - else: - logger.info("NSFW checker loading disabled") + try: + cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH) + cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH) + logger.info("NSFW checker initialized") + except Exception as e: + logger.warning(f"Could not load NSFW checker: {str(e)}") cls.tried_load = True @classmethod @@ -54,7 +51,8 @@ class SafetyChecker: def has_nsfw_concept(cls, image: Image.Image) -> bool: if not cls.safety_checker_available(): return False - + assert cls.safety_checker is not None + assert cls.feature_extractor is not None device = choose_torch_device() features = cls.feature_extractor([image], return_tensors="pt") features.to(device)