fix(config): fix nsfw_checker handling

This setting was hardcoded to True. Rework logic around it to not conditionally check the setting.
This commit is contained in:
psychedelicious 2024-03-11 22:47:01 +11:00
parent 3fb116155b
commit fbe3afa5e1
2 changed files with 14 additions and 19 deletions

View File

@ -12,7 +12,6 @@ from pydantic import BaseModel, Field
from invokeai.app.invocations.upscale import ESRGAN_MODELS from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.logging import logging from invokeai.backend.util.logging import logging
@ -114,9 +113,7 @@ async def get_config() -> AppConfig:
if SafetyChecker.safety_checker_available(): if SafetyChecker.safety_checker_available():
nsfw_methods.append("nsfw_checker") nsfw_methods.append("nsfw_checker")
watermarking_methods = [] watermarking_methods = ["invisible_watermark"]
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append("invisible_watermark")
return AppConfig( return AppConfig(
infill_methods=infill_methods, infill_methods=infill_methods,

View File

@ -4,15 +4,18 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed. configuration variable, that allows the checker to be supressed.
""" """
import numpy as np import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image from PIL import Image
from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings
config = InvokeAIAppConfig.get_config() config = get_config()
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker" CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -31,18 +34,12 @@ class SafetyChecker:
if cls.tried_load: if cls.tried_load:
return return
if config.nsfw_checker: try:
try: cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
from transformers import AutoFeatureExtractor logger.info("NSFW checker initialized")
except Exception as e:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH) logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
cls.tried_load = True cls.tried_load = True
@classmethod @classmethod
@ -54,7 +51,8 @@ class SafetyChecker:
def has_nsfw_concept(cls, image: Image.Image) -> bool: def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available(): if not cls.safety_checker_available():
return False return False
assert cls.safety_checker is not None
assert cls.feature_extractor is not None
device = choose_torch_device() device = choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt") features = cls.feature_extractor([image], return_tensors="pt")
features.to(device) features.to(device)