mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(nodes): fix nsfw checker model download
This commit is contained in:
parent
eef6fcf286
commit
9c819f0fd8
@ -13,7 +13,6 @@ from pydantic import BaseModel, Field
|
|||||||
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||||
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
|
||||||
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch
|
||||||
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
|
||||||
from invokeai.backend.util.logging import logging
|
from invokeai.backend.util.logging import logging
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
|
|
||||||
@ -109,9 +108,7 @@ async def get_config() -> AppConfig:
|
|||||||
upscaling_models.append(str(Path(model).stem))
|
upscaling_models.append(str(Path(model).stem))
|
||||||
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
upscaler = Upscaler(upscaling_method="esrgan", upscaling_models=upscaling_models)
|
||||||
|
|
||||||
nsfw_methods = []
|
nsfw_methods = ["nsfw_checker"]
|
||||||
if SafetyChecker.safety_checker_available():
|
|
||||||
nsfw_methods.append("nsfw_checker")
|
|
||||||
|
|
||||||
watermarking_methods = ["invisible_watermark"]
|
watermarking_methods = ["invisible_watermark"]
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
from PIL import Image
|
from PIL import Image, ImageFilter
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
@ -16,6 +16,7 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
|
repo_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
CHECKER_PATH = "core/convert/stable-diffusion-safety-checker"
|
||||||
|
|
||||||
|
|
||||||
@ -24,30 +25,30 @@ class SafetyChecker:
|
|||||||
Wrapper around SafetyChecker model.
|
Wrapper around SafetyChecker model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
safety_checker = None
|
|
||||||
feature_extractor = None
|
feature_extractor = None
|
||||||
tried_load: bool = False
|
safety_checker = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _load_safety_checker(cls):
|
def _load_safety_checker(cls):
|
||||||
if cls.tried_load:
|
if cls.safety_checker is not None and cls.feature_extractor is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(get_config().models_path / CHECKER_PATH)
|
model_path = get_config().models_path / CHECKER_PATH
|
||||||
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(get_config().models_path / CHECKER_PATH)
|
if model_path.exists():
|
||||||
|
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(model_path)
|
||||||
|
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(model_path)
|
||||||
|
else:
|
||||||
|
model_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
|
||||||
|
cls.feature_extractor.save_pretrained(model_path, safe_serialization=True)
|
||||||
|
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(repo_id)
|
||||||
|
cls.safety_checker.save_pretrained(model_path, safe_serialization=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
logger.warning(f"Could not load NSFW checker: {str(e)}")
|
||||||
cls.tried_load = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def safety_checker_available(cls) -> bool:
|
|
||||||
return Path(get_config().models_path, CHECKER_PATH).exists()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def has_nsfw_concept(cls, image: Image.Image) -> bool:
|
def has_nsfw_concept(cls, image: Image.Image) -> bool:
|
||||||
if not cls.safety_checker_available() and cls.tried_load:
|
|
||||||
return False
|
|
||||||
cls._load_safety_checker()
|
cls._load_safety_checker()
|
||||||
if cls.safety_checker is None or cls.feature_extractor is None:
|
if cls.safety_checker is None or cls.feature_extractor is None:
|
||||||
return False
|
return False
|
||||||
@ -60,3 +61,24 @@ class SafetyChecker:
|
|||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
|
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
|
||||||
return has_nsfw_concept[0]
|
return has_nsfw_concept[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def blur_if_nsfw(cls, image: Image.Image) -> Image.Image:
|
||||||
|
if cls.has_nsfw_concept(image):
|
||||||
|
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||||
|
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
caution = cls._get_caution_img()
|
||||||
|
# Center the caution image on the blurred image
|
||||||
|
x = (blurry_image.width - caution.width) // 2
|
||||||
|
y = (blurry_image.height - caution.height) // 2
|
||||||
|
blurry_image.paste(caution, (x, y), caution)
|
||||||
|
image = blurry_image
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_caution_img(cls) -> Image.Image:
|
||||||
|
import invokeai.app.assets.images as image_assets
|
||||||
|
|
||||||
|
caution = Image.open(Path(image_assets.__path__[0]) / "caution.png")
|
||||||
|
return caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
Loading…
Reference in New Issue
Block a user