fix(config): fix nsfw_checker handling

This setting was hardcoded to True. Rework logic around it to not conditionally check the setting.
This commit is contained in:
psychedelicious
2024-03-11 22:47:01 +11:00
parent 3fb116155b
commit fbe3afa5e1
2 changed files with 14 additions and 19 deletions

View File

@ -4,15 +4,18 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image
from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.silence_warnings import SilenceWarnings
config = InvokeAIAppConfig.get_config()
config = get_config()
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
@ -31,18 +34,12 @@ class SafetyChecker:
if cls.tried_load:
return
if config.nsfw_checker:
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
try:
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
cls.tried_load = True
@classmethod
@ -54,7 +51,8 @@ class SafetyChecker:
def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available():
return False
assert cls.safety_checker is not None
assert cls.feature_extractor is not None
device = choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)