fix(config): fix nsfw_checker handling
This setting was hardcoded to True. Rework the logic around it so the setting is no longer conditionally checked.
parent 3fb116155b
commit fbe3afa5e1
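
The core of the change: the loader no longer consults config.nsfw_checker (which was hardcoded to True) and instead always attempts a one-shot load, with availability reported through SafetyChecker.safety_checker_available(). Below is a minimal, self-contained sketch of that try-once lazy-load pattern; the class and method names are illustrative stand-ins, not InvokeAI API.

import logging

logger = logging.getLogger(__name__)


class LazyChecker:
    """Illustrative stand-in for the SafetyChecker singleton pattern."""

    model = None        # loaded lazily on first use
    tried_load = False  # guards so a failing load is attempted only once

    @classmethod
    def _load(cls) -> None:
        if cls.tried_load:
            return
        try:
            cls.model = object()  # stand-in for from_pretrained(...)
            logger.info("checker initialized")
        except Exception as e:
            logger.warning(f"Could not load checker: {str(e)}")
        cls.tried_load = True

    @classmethod
    def available(cls) -> bool:
        cls._load()
        return cls.model is not None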
invokeai/app/api/routers/app_info.py

@@ -12,7 +12,6 @@ from pydantic import BaseModel, Field
 
 from invokeai.app.invocations.upscale import ESRGAN_MODELS
 from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
-from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
 from invokeai.backend.image_util.patchmatch import PatchMatch
 from invokeai.backend.image_util.safety_checker import SafetyChecker
 from invokeai.backend.util.logging import logging
@@ -114,9 +113,7 @@ async def get_config() -> AppConfig:
     if SafetyChecker.safety_checker_available():
         nsfw_methods.append("nsfw_checker")
 
-    watermarking_methods = []
-    if InvisibleWatermark.invisible_watermark_available():
-        watermarking_methods.append("invisible_watermark")
+    watermarking_methods = ["invisible_watermark"]
 
     return AppConfig(
         infill_methods=infill_methods,
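Net effect on the get_config() response: invisible watermarking is always advertised, while the NSFW method still depends on whether the checker model is actually present. A sketch of the resulting list-building code, where the initialization of nsfw_methods is an assumption (it is not visible in the hunk):

from invokeai.backend.image_util.safety_checker import SafetyChecker

nsfw_methods: list[str] = []  # assumed starting value; not shown in the hunk
if SafetyChecker.safety_checker_available():
    nsfw_methods.append("nsfw_checker")

# No InvisibleWatermark probe at request time anymore; always reported.
watermarking_methods = ["invisible_watermark"]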
invokeai/backend/image_util/safety_checker.py

@@ -4,15 +4,18 @@ wraps the safety_checker model. It respects the global "nsfw_checker"
 configuration variable, that allows the checker to be supressed.
 """
 
 import numpy as np
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from PIL import Image
+from transformers import AutoFeatureExtractor
 
 import invokeai.backend.util.logging as logger
-from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.app.services.config.config_default import get_config
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
-config = InvokeAIAppConfig.get_config()
+config = get_config()
 
 CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
 
@@ -31,18 +34,12 @@ class SafetyChecker:
         if cls.tried_load:
             return
 
-        if config.nsfw_checker:
-            try:
-                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-                from transformers import AutoFeatureExtractor
-
-                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
-                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
-                logger.info("NSFW checker initialized")
-            except Exception as e:
-                logger.warning(f"Could not load NSFW checker: {str(e)}")
-        else:
-            logger.info("NSFW checker loading disabled")
+        try:
+            cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
+            cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
+            logger.info("NSFW checker initialized")
+        except Exception as e:
+            logger.warning(f"Could not load NSFW checker: {str(e)}")
         cls.tried_load = True
 
     @classmethod
@@ -54,7 +51,8 @@ class SafetyChecker:
     def has_nsfw_concept(cls, image: Image.Image) -> bool:
         if not cls.safety_checker_available():
            return False
-
+        assert cls.safety_checker is not None
+        assert cls.feature_extractor is not None
         device = choose_torch_device()
         features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
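For completeness, a hedged usage sketch of the reworked class: assuming safety_checker_available() triggers the lazy load (as the guard in has_nsfw_concept suggests), a failed load simply degrades to has_nsfw_concept returning False. The image path is hypothetical.

from PIL import Image

from invokeai.backend.image_util.safety_checker import SafetyChecker

image = Image.open("output.png")  # hypothetical file
if SafetyChecker.has_nsfw_concept(image):
    print("image was flagged by the NSFW checker")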