fix(nodes): do not load NSFW checker model on startup

Just check whether the model path exists to determine if the checker is "available". Load the model only when it is actually needed.
psychedelicious 2024-03-19 17:00:09 +11:00
parent 0e51495071
commit 2eacbb4d9d

@@ -4,6 +4,8 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
 configuration variable, that allows the checker to be supressed.
 """
 
+from pathlib import Path
+
 import numpy as np
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from PIL import Image
@@ -41,15 +43,15 @@ class SafetyChecker:
     @classmethod
     def safety_checker_available(cls) -> bool:
-        cls._load_safety_checker()
-        return cls.safety_checker is not None
+        return Path(get_config().models_path, CHECKER_PATH).exists()
 
     @classmethod
     def has_nsfw_concept(cls, image: Image.Image) -> bool:
-        if not cls.safety_checker_available():
+        if not cls.safety_checker_available() and cls.tried_load:
             return False
+        cls._load_safety_checker()
+        if cls.safety_checker is None or cls.feature_extractor is None:
+            return False
-        assert cls.safety_checker is not None
-        assert cls.feature_extractor is not None
         device = choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
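
The new has_nsfw_concept path relies on cls.tried_load and cls._load_safety_checker(), neither of which is shown in these hunks. Below is a minimal sketch of the lazy loader they imply, assuming tried_load is a class-level flag set after the first load attempt. It is written as if it sat in the same module, so it reuses the get_config() and CHECKER_PATH names that appear in the hunk above; the from_pretrained calls are illustrative, not taken from this commit.

from pathlib import Path

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor


class SafetyChecker:
    safety_checker = None
    feature_extractor = None
    tried_load = False  # flipped after the first load attempt, whether or not it succeeded

    @classmethod
    def _load_safety_checker(cls):
        if cls.tried_load:
            return  # attempt the (slow) disk load at most once per process
        try:
            # get_config() and CHECKER_PATH are the module's existing config helper
            # and relative model path, as used in the diff above.
            checker_dir = Path(get_config().models_path, CHECKER_PATH)
            cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(checker_dir)
            cls.feature_extractor = AutoFeatureExtractor.from_pretrained(checker_dir)
        except Exception:
            # Missing or unreadable weights leave both attributes None, so
            # has_nsfw_concept falls through to "return False" instead of raising.
            pass
        cls.tried_load = True

With a guard like this, the checker weights are read from disk at most once, on the first has_nsfw_concept() call, and safety_checker_available() stays a cheap filesystem check rather than a model load at startup.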