Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
refactored code; added watermark and nsfw facilities to app config route
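This commit wires the two new helpers shown below into the app config route so the frontend can discover whether watermarking and NSFW checking are active. The route itself is not part of this hunk; the following is a minimal sketch of how such a handler might surface the flags, assuming a FastAPI router. The route path and response keys are illustrative assumptions, not taken from the commit.

# Illustrative sketch only: route path and response keys are assumptions,
# not part of this commit. It simply reads the two class-level helpers.
from fastapi import APIRouter

from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend.image_util.safety_checker import SafetyChecker

app_router = APIRouter(prefix="/api/v1/app", tags=["app"])


@app_router.get("/config")
async def get_app_config() -> dict:
    # Report which post-processing facilities are enabled in the current config.
    return {
        "watermarking_enabled": InvisibleWatermark.invisible_watermark_available(),
        "nsfw_checking_enabled": SafetyChecker.safety_checker_available(),
    }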
invokeai/backend/image_util/invisible_watermark.py (Normal file, 34 lines added)
@@ -0,0 +1,34 @@
"""
This module defines a singleton class, "InvisibleWatermark", that wraps the
invisible-watermark encoder. It respects the global "invisible_watermark"
configuration variable, which allows the watermarking to be suppressed.
"""
import numpy as np
import cv2
from PIL import Image
from imwatermark import WatermarkEncoder

from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger

config = InvokeAIAppConfig.get_config()


class InvisibleWatermark:
    """
    Wrapper around the invisible-watermark encoder.
    """

    @classmethod
    def invisible_watermark_available(cls) -> bool:
        return config.invisible_watermark

    @classmethod
    def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
        if not cls.invisible_watermark_available():
            return image
        logger.debug(f'Applying invisible watermark "{watermark_text}"')
        # The encoder works on OpenCV-style BGR arrays, so convert from PIL RGB first.
        bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
        encoder = WatermarkEncoder()
        encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
        # Embed the payload with the DWT-DCT method, then convert back to RGBA.
        bgr_encoded = encoder.encode(bgr, 'dwtDct')
        return Image.fromarray(
            cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
        ).convert("RGBA")
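A minimal usage sketch for the new encoder wrapper (not part of the diff): it watermarks a PIL image and then decodes the payload back with imwatermark's WatermarkDecoder to confirm the round trip. The file names and the payload string are illustrative, and it assumes invisible_watermark is enabled in the app config.

# Illustrative usage sketch, not part of this commit. Assumes an
# "output.png" on disk and invisible_watermark enabled in the app config.
import cv2
import numpy as np
from imwatermark import WatermarkDecoder
from PIL import Image

from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark

payload = "InvokeAI"  # 8 bytes -> 64 bits
marked = InvisibleWatermark.add_watermark(Image.open("output.png"), payload)
marked.save("output_marked.png")

# Decode to verify: the decoder needs the payload length in bits and the
# same "dwtDct" method used by the encoder.
bgr = cv2.cvtColor(np.array(marked.convert("RGB")), cv2.COLOR_RGB2BGR)
decoder = WatermarkDecoder('bytes', len(payload.encode('utf-8')) * 8)
print(decoder.decode(bgr, 'dwtDct').decode('utf-8'))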
invokeai/backend/image_util/safety_checker.py (Normal file, 63 lines added)
@@ -0,0 +1,63 @@
"""
This module defines a singleton class, "SafetyChecker", that wraps the
safety-checker model. It respects the global "nsfw_checker" configuration
variable, which allows the checker to be suppressed.
"""
import numpy as np
from PIL import Image

from invokeai.backend import SilenceWarnings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
import invokeai.backend.util.logging as logger

config = InvokeAIAppConfig.get_config()

CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'


class SafetyChecker:
    """
    Wrapper around the StableDiffusionSafetyChecker model.
    """

    safety_checker = None
    feature_extractor = None
    tried_load: bool = False

    @classmethod
    def _load_safety_checker(cls):
        # Load the model lazily, and only attempt it once.
        if cls.tried_load:
            return

        if config.nsfw_checker:
            try:
                from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                from transformers import AutoFeatureExtractor

                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
                    config.models_path / CHECKER_PATH
                )
                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(
                    config.models_path / CHECKER_PATH
                )
                logger.info('NSFW checker initialized')
            except Exception as e:
                logger.warning(f'Could not load NSFW checker: {str(e)}')
        else:
            logger.info('NSFW checker loading disabled')
        cls.tried_load = True

    @classmethod
    def safety_checker_available(cls) -> bool:
        cls._load_safety_checker()
        return cls.safety_checker is not None

    @classmethod
    def has_nsfw_concept(cls, image: Image.Image) -> bool:
        if not cls.safety_checker_available():
            return False

        device = choose_torch_device()
        features = cls.feature_extractor([image], return_tensors="pt")
        features.to(device)
        cls.safety_checker.to(device)
        # The checker expects a batch of float images in NCHW order, scaled to [0, 1].
        x_image = np.array(image).astype(np.float32) / 255.0
        x_image = x_image[None].transpose(0, 3, 1, 2)
        with SilenceWarnings():
            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
        return has_nsfw_concept[0]
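A matching usage sketch for the checker (again not part of the diff): run it on a generated image and blur the output when an NSFW concept is flagged, which is how such a check is typically consumed. The file name and blur radius are illustrative; it assumes nsfw_checker is enabled and the checker weights are installed under models_path / CHECKER_PATH.

# Illustrative usage sketch, not part of this commit. Assumes nsfw_checker
# enabled in the app config and the safety-checker weights installed.
from PIL import Image, ImageFilter

from invokeai.backend.image_util.safety_checker import SafetyChecker

image = Image.open("output.png")
if SafetyChecker.has_nsfw_concept(image):
    # Blur flagged images rather than discarding them.
    image = image.filter(ImageFilter.GaussianBlur(radius=32))
image.save("output_checked.png")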