fix(nodes): do not load NSFW checker model on startup

Just check if the path exists to determine if it is "available". When needed, load it.
This commit is contained in:
psychedelicious 2024-03-19 17:00:09 +11:00
parent 0e51495071
commit 2eacbb4d9d

View File

@ -4,6 +4,8 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed. configuration variable, that allows the checker to be supressed.
""" """
from pathlib import Path
import numpy as np import numpy as np
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from PIL import Image from PIL import Image
@ -41,15 +43,15 @@ class SafetyChecker:
@classmethod @classmethod
def safety_checker_available(cls) -> bool: def safety_checker_available(cls) -> bool:
cls._load_safety_checker() return Path(get_config().models_path, CHECKER_PATH).exists()
return cls.safety_checker is not None
@classmethod @classmethod
def has_nsfw_concept(cls, image: Image.Image) -> bool: def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available(): if not cls.safety_checker_available() and cls.tried_load:
return False
cls._load_safety_checker()
if cls.safety_checker is None or cls.feature_extractor is None:
return False return False
assert cls.safety_checker is not None
assert cls.feature_extractor is not None
device = choose_torch_device() device = choose_torch_device()
features = cls.feature_extractor([image], return_tensors="pt") features = cls.feature_extractor([image], return_tensors="pt")
features.to(device) features.to(device)