mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactored code; added watermark and nsfw facilities to app config route
This commit is contained in:
parent
4194a0ed99
commit
efa615a8fd
@ -1,9 +1,15 @@
|
|||||||
|
import typing
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
from pathlib import Path
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from invokeai.backend.image_util.patchmatch import PatchMatch
|
from invokeai.backend.image_util.patchmatch import PatchMatch
|
||||||
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
|
from invokeai.app.invocations.upscale import ESRGAN_MODELS
|
||||||
|
|
||||||
from invokeai.version import __version__
|
from invokeai.version import __version__
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -30,6 +36,10 @@ class AppConfig(BaseModel):
|
|||||||
"""App Config Response"""
|
"""App Config Response"""
|
||||||
|
|
||||||
infill_methods: list[str] = Field(description="List of available infill methods")
|
infill_methods: list[str] = Field(description="List of available infill methods")
|
||||||
|
upscaling_methods: list[str] = Field(description="List of upscaling methods")
|
||||||
|
upscaling_models: list[str] = Field(description="List of postprocessing methods")
|
||||||
|
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
|
||||||
|
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
|
||||||
|
|
||||||
|
|
||||||
@app_router.get(
|
@app_router.get(
|
||||||
@ -46,7 +56,27 @@ async def get_config() -> AppConfig:
|
|||||||
infill_methods = ['tile']
|
infill_methods = ['tile']
|
||||||
if PatchMatch.patchmatch_available():
|
if PatchMatch.patchmatch_available():
|
||||||
infill_methods.append('patchmatch')
|
infill_methods.append('patchmatch')
|
||||||
return AppConfig(infill_methods=infill_methods)
|
|
||||||
|
upscaling_methods = ['esrgan']
|
||||||
|
upscaling_models = []
|
||||||
|
for model in typing.get_args(ESRGAN_MODELS):
|
||||||
|
upscaling_models.append(str(Path(model).stem))
|
||||||
|
|
||||||
|
nsfw_methods = []
|
||||||
|
if SafetyChecker.safety_checker_available():
|
||||||
|
nsfw_methods.append('nsfw_checker')
|
||||||
|
|
||||||
|
watermarking_methods = []
|
||||||
|
if InvisibleWatermark.invisible_watermark_available():
|
||||||
|
watermarking_methods.append('invisible_watermark')
|
||||||
|
|
||||||
|
return AppConfig(
|
||||||
|
infill_methods=infill_methods,
|
||||||
|
upscaling_methods=upscaling_methods,
|
||||||
|
upscaling_models=upscaling_models,
|
||||||
|
nsfw_methods=nsfw_methods,
|
||||||
|
watermarking_methods=watermarking_methods,
|
||||||
|
)
|
||||||
|
|
||||||
@app_router.get(
|
@app_router.get(
|
||||||
"/logging",
|
"/logging",
|
||||||
|
@ -8,8 +8,6 @@ from pydantic import Field
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from invokeai.app.invocations.metadata import CoreMetadata
|
from invokeai.app.invocations.metadata import CoreMetadata
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
from ..models.image import (
|
from ..models.image import (
|
||||||
ImageCategory, ImageField, ResourceOrigin,
|
ImageCategory, ImageField, ResourceOrigin,
|
||||||
PILInvocationConfig, ImageOutput, MaskOutput,
|
PILInvocationConfig, ImageOutput, MaskOutput,
|
||||||
@ -19,9 +17,8 @@ from .baseinvocation import (
|
|||||||
InvocationContext,
|
InvocationContext,
|
||||||
InvocationConfig,
|
InvocationConfig,
|
||||||
)
|
)
|
||||||
from ..services.config import InvokeAIAppConfig
|
from invokeai.backend.image_util.safety_checker import SafetyChecker
|
||||||
from invokeai.backend.util.devices import choose_torch_device
|
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
|
||||||
from invokeai.backend import SilenceWarnings
|
|
||||||
|
|
||||||
class LoadImageInvocation(BaseInvocation):
|
class LoadImageInvocation(BaseInvocation):
|
||||||
"""Load an image and provide it as output."""
|
"""Load an image and provide it as output."""
|
||||||
@ -614,14 +611,12 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
|
|
||||||
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
"""Add blur to NSFW-flagged images"""
|
"""Add blur to NSFW-flagged images"""
|
||||||
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().nsfw_checker
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["img_nsfw"] = "img_nsfw"
|
type: Literal["img_nsfw"] = "img_nsfw"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||||
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the NSFW checker is enabled")
|
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@ -636,30 +631,14 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
|
|
||||||
config = context.services.configuration
|
|
||||||
logger = context.services.logger
|
logger = context.services.logger
|
||||||
device = choose_torch_device()
|
logger.debug("Running NSFW checker")
|
||||||
|
if SafetyChecker.has_nsfw_concept(image):
|
||||||
if self.enabled:
|
logger.info("A potentially NSFW image has been detected. Image will be blurred.")
|
||||||
logger.debug("Running NSFW checker")
|
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
|
caution = self._get_caution_img()
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
|
blurry_image.paste(caution,(0,0),caution)
|
||||||
|
image = blurry_image
|
||||||
features = feature_extractor([image], return_tensors="pt")
|
|
||||||
features.to(device)
|
|
||||||
safety_checker.to(device)
|
|
||||||
|
|
||||||
x_image = numpy.array(image).astype(numpy.float32) / 255.0
|
|
||||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
|
||||||
with SilenceWarnings():
|
|
||||||
checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
|
|
||||||
|
|
||||||
logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
|
|
||||||
if has_nsfw_concept[0]:
|
|
||||||
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
||||||
caution = self._get_caution_img()
|
|
||||||
blurry_image.paste(caution,(0,0),caution)
|
|
||||||
image = blurry_image
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image,
|
image=image,
|
||||||
@ -685,16 +664,12 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
||||||
""" Add an invisible watermark to an image """
|
""" Add an invisible watermark to an image """
|
||||||
|
|
||||||
# to avoid circular import
|
|
||||||
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().invisible_watermark
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["img_watermark"] = "img_watermark"
|
type: Literal["img_watermark"] = "img_watermark"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
image: Optional[ImageField] = Field(default=None, description="The image to check")
|
||||||
text: str = Field(default='InvokeAI', description="Watermark text")
|
text: str = Field(default='InvokeAI', description="Watermark text")
|
||||||
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the invisible watermark is enabled")
|
|
||||||
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
@ -707,25 +682,10 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
import cv2
|
|
||||||
from imwatermark import WatermarkEncoder
|
|
||||||
|
|
||||||
logger = context.services.logger
|
|
||||||
image = context.services.images.get_pil_image(self.image.image_name)
|
image = context.services.images.get_pil_image(self.image.image_name)
|
||||||
if self.enabled:
|
new_image = InvisibleWatermark.add_watermark(image, self.text)
|
||||||
logger.debug("Running invisible watermarker")
|
|
||||||
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
|
||||||
wm = self.text
|
|
||||||
encoder = WatermarkEncoder()
|
|
||||||
encoder.set_watermark('bytes', wm.encode('utf-8'))
|
|
||||||
bgr_encoded = encoder.encode(bgr, 'dwtDct')
|
|
||||||
new_image = Image.fromarray(
|
|
||||||
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
|
|
||||||
).convert("RGBA")
|
|
||||||
image = new_image
|
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image,
|
image=new_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
image_category=ImageCategory.GENERAL,
|
image_category=ImageCategory.GENERAL,
|
||||||
node_id=self.id,
|
node_id=self.id,
|
||||||
|
@ -12,5 +12,4 @@ from .model_management import (
|
|||||||
ModelManager, ModelCache, BaseModelType,
|
ModelManager, ModelCache, BaseModelType,
|
||||||
ModelType, SubModelType, ModelInfo
|
ModelType, SubModelType, ModelInfo
|
||||||
)
|
)
|
||||||
from .safety_checker import SafetyChecker
|
|
||||||
from .model_management.models import SilenceWarnings
|
from .model_management.models import SilenceWarnings
|
||||||
|
@ -28,7 +28,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from ..image_util import configure_model_padding
|
from ..image_util import configure_model_padding
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
from ..stable_diffusion.schedulers import SCHEDULER_MAP
|
||||||
|
|
||||||
@ -52,7 +51,6 @@ class InvokeAIGeneratorBasicParams:
|
|||||||
v_symmetry_time_pct: Optional[float]=None
|
v_symmetry_time_pct: Optional[float]=None
|
||||||
variation_amount: float = 0.0
|
variation_amount: float = 0.0
|
||||||
with_variations: list=field(default_factory=list)
|
with_variations: list=field(default_factory=list)
|
||||||
safety_checker: Optional[SafetyChecker]=None
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorOutput:
|
class InvokeAIGeneratorOutput:
|
||||||
@ -240,7 +238,6 @@ class Generator:
|
|||||||
self.seed = None
|
self.seed = None
|
||||||
self.latent_channels = model.unet.config.in_channels
|
self.latent_channels = model.unet.config.in_channels
|
||||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||||
self.safety_checker = None
|
|
||||||
self.perlin = 0.0
|
self.perlin = 0.0
|
||||||
self.threshold = 0
|
self.threshold = 0
|
||||||
self.variation_amount = 0
|
self.variation_amount = 0
|
||||||
@ -277,12 +274,10 @@ class Generator:
|
|||||||
perlin=0.0,
|
perlin=0.0,
|
||||||
h_symmetry_time_pct=None,
|
h_symmetry_time_pct=None,
|
||||||
v_symmetry_time_pct=None,
|
v_symmetry_time_pct=None,
|
||||||
safety_checker: SafetyChecker=None,
|
|
||||||
free_gpu_mem: bool = False,
|
free_gpu_mem: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
scope = nullcontext
|
scope = nullcontext
|
||||||
self.safety_checker = safety_checker
|
|
||||||
self.free_gpu_mem = free_gpu_mem
|
self.free_gpu_mem = free_gpu_mem
|
||||||
attention_maps_images = []
|
attention_maps_images = []
|
||||||
attention_maps_callback = lambda saver: attention_maps_images.append(
|
attention_maps_callback = lambda saver: attention_maps_images.append(
|
||||||
@ -329,9 +324,6 @@ class Generator:
|
|||||||
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
# Pass on the seed in case a layer beneath us needs to generate noise on its own.
|
||||||
image = make_image(x_T, seed)
|
image = make_image(x_T, seed)
|
||||||
|
|
||||||
if self.safety_checker is not None:
|
|
||||||
image = self.safety_checker.check(image)
|
|
||||||
|
|
||||||
results.append([image, seed, attention_maps_images])
|
results.append([image, seed, attention_maps_images])
|
||||||
|
|
||||||
if image_callback is not None:
|
if image_callback is not None:
|
||||||
|
34
invokeai/backend/image_util/invisible_watermark.py
Normal file
34
invokeai/backend/image_util/invisible_watermark.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
This module defines a singleton object, "invisible_watermark" that
|
||||||
|
wraps the invisible watermark model. It respects the global "invisible_watermark"
|
||||||
|
configuration variable, that allows the watermarking to be supressed.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PIL import Image
|
||||||
|
from imwatermark import WatermarkEncoder
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
class InvisibleWatermark:
|
||||||
|
"""
|
||||||
|
Wrapper around InvisibleWatermark module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def invisible_watermark_available(self) -> bool:
|
||||||
|
return config.invisible_watermark
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_watermark(self, image: Image, watermark_text:str) -> Image:
|
||||||
|
if not self.invisible_watermark_available():
|
||||||
|
return image
|
||||||
|
logger.debug(f'Applying invisible watermark "{watermark_text}"')
|
||||||
|
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||||
|
encoder = WatermarkEncoder()
|
||||||
|
encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
|
||||||
|
bgr_encoded = encoder.encode(bgr, 'dwtDct')
|
||||||
|
return Image.fromarray(
|
||||||
|
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
|
||||||
|
).convert("RGBA")
|
63
invokeai/backend/image_util/safety_checker.py
Normal file
63
invokeai/backend/image_util/safety_checker.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
"""
|
||||||
|
This module defines a singleton object, "safety_checker" that
|
||||||
|
wraps the safety_checker model. It respects the global "nsfw_checker"
|
||||||
|
configuration variable, that allows the checker to be supressed.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from invokeai.backend import SilenceWarnings
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
|
config = InvokeAIAppConfig.get_config()
|
||||||
|
|
||||||
|
CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'
|
||||||
|
|
||||||
|
class SafetyChecker:
|
||||||
|
"""
|
||||||
|
Wrapper around SafetyChecker model.
|
||||||
|
"""
|
||||||
|
safety_checker = None
|
||||||
|
feature_extractor = None
|
||||||
|
tried_load: bool = False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _load_safety_checker(self):
|
||||||
|
if self.tried_load:
|
||||||
|
return
|
||||||
|
|
||||||
|
if config.nsfw_checker:
|
||||||
|
try:
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
|
config.models_path / CHECKER_PATH
|
||||||
|
)
|
||||||
|
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
config.models_path / CHECKER_PATH)
|
||||||
|
logger.info('NSFW checker initialized')
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f'Could not load NSFW checker: {str(e)}')
|
||||||
|
else:
|
||||||
|
logger.info('NSFW checker loading disabled')
|
||||||
|
self.tried_load = True
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def safety_checker_available(self) -> bool:
|
||||||
|
self._load_safety_checker()
|
||||||
|
return self.safety_checker is not None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has_nsfw_concept(self, image: Image) -> bool:
|
||||||
|
if not self.safety_checker_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
device = choose_torch_device()
|
||||||
|
features = self.feature_extractor([image], return_tensors="pt")
|
||||||
|
features.to(device)
|
||||||
|
self.safety_checker.to(device)
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
with SilenceWarnings():
|
||||||
|
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
|
||||||
|
return has_nsfw_concept[0]
|
@ -1,77 +0,0 @@
|
|||||||
'''
|
|
||||||
SafetyChecker class - checks images against the StabilityAI NSFW filter
|
|
||||||
and blurs images that contain potential NSFW content.
|
|
||||||
'''
|
|
||||||
import diffusers
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import traceback
|
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
|
||||||
StableDiffusionSafetyChecker,
|
|
||||||
)
|
|
||||||
from pathlib import Path
|
|
||||||
from PIL import Image, ImageFilter
|
|
||||||
from transformers import AutoFeatureExtractor
|
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
|
||||||
import invokeai.backend.util.logging as logger
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from .util import CPU_DEVICE
|
|
||||||
|
|
||||||
config = InvokeAIAppConfig.get_config()
|
|
||||||
|
|
||||||
class SafetyChecker(object):
|
|
||||||
CAUTION_IMG = "caution.png"
|
|
||||||
|
|
||||||
def __init__(self, device: torch.device):
|
|
||||||
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
|
||||||
caution = Image.open(path)
|
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
try:
|
|
||||||
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
|
|
||||||
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
|
|
||||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
|
||||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
|
|
||||||
except Exception:
|
|
||||||
logger.error(
|
|
||||||
"An error was encountered while installing the safety checker:"
|
|
||||||
)
|
|
||||||
print(traceback.format_exc())
|
|
||||||
|
|
||||||
def check(self, image: Image.Image):
|
|
||||||
"""
|
|
||||||
Check provided image against the StabilityAI safety checker and return
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.safety_checker.to(self.device)
|
|
||||||
features = self.safety_feature_extractor([image], return_tensors="pt")
|
|
||||||
features.to(self.device)
|
|
||||||
|
|
||||||
# unfortunately checker requires the numpy version, so we have to convert back
|
|
||||||
x_image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
|
||||||
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
checked_image, has_nsfw_concept = self.safety_checker(
|
|
||||||
images=x_image, clip_input=features.pixel_values
|
|
||||||
)
|
|
||||||
self.safety_checker.to(CPU_DEVICE) # offload
|
|
||||||
if has_nsfw_concept[0]:
|
|
||||||
logger.warning(
|
|
||||||
"An image with potential non-safe content has been detected. A blurred image will be returned."
|
|
||||||
)
|
|
||||||
return self.blur(image)
|
|
||||||
else:
|
|
||||||
return image
|
|
||||||
|
|
||||||
def blur(self, input):
|
|
||||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
||||||
try:
|
|
||||||
if caution := self.caution_img:
|
|
||||||
blurry.paste(caution, (0, 0), caution)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
return blurry
|
|
Loading…
Reference in New Issue
Block a user