fix(backend): fix nsfw/watermarker util types

mirror of https://github.com/invoke-ai/InvokeAI
parent 284a257c25
commit 252c9a5f5a
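The change is confined to the classmethods of `InvisibleWatermark` and `SafetyChecker`: their first parameter is renamed from `self` to `cls`, attribute access goes through `cls` accordingly, and the PIL annotations are tightened from the bare `Image` module to the `Image.Image` class.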
@@ -20,12 +20,12 @@ class InvisibleWatermark:
     """

     @classmethod
-    def invisible_watermark_available(self) -> bool:
+    def invisible_watermark_available(cls) -> bool:
         return config.invisible_watermark

     @classmethod
-    def add_watermark(self, image: Image, watermark_text: str) -> Image:
-        if not self.invisible_watermark_available():
+    def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
+        if not cls.invisible_watermark_available():
             return image
         logger.debug(f'Applying invisible watermark "{watermark_text}"')
         bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
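Note on the annotation fix above: in Pillow, `PIL.Image` is a module, and the image class inside it is `Image.Image`. Annotating a parameter or return value with the bare name `Image` therefore points at a module, which static type checkers reject as an invalid type. A minimal illustration of the distinction (the `load_thumbnail` helper is hypothetical, not part of this commit):

from PIL import Image

def load_thumbnail(path: str) -> Image.Image:
    # Image.open() returns an Image.Image instance.
    # Annotating this function as `-> Image` would name the module,
    # not the class, and mypy/pyright would flag it as invalid.
    img = Image.open(path)
    img.thumbnail((128, 128))  # resizes in place, preserving aspect ratio
    return img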
@@ -26,8 +26,8 @@ class SafetyChecker:
     tried_load: bool = False

     @classmethod
-    def _load_safety_checker(self):
-        if self.tried_load:
+    def _load_safety_checker(cls):
+        if cls.tried_load:
             return

         if config.nsfw_checker:
@@ -35,31 +35,31 @@ class SafetyChecker:
                 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                 from transformers import AutoFeatureExtractor

-                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
-                self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
                 logger.info("NSFW checker initialized")
             except Exception as e:
                 logger.warning(f"Could not load NSFW checker: {str(e)}")
         else:
             logger.info("NSFW checker loading disabled")
-        self.tried_load = True
+        cls.tried_load = True

     @classmethod
-    def safety_checker_available(self) -> bool:
-        self._load_safety_checker()
-        return self.safety_checker is not None
+    def safety_checker_available(cls) -> bool:
+        cls._load_safety_checker()
+        return cls.safety_checker is not None

     @classmethod
-    def has_nsfw_concept(self, image: Image) -> bool:
-        if not self.safety_checker_available():
+    def has_nsfw_concept(cls, image: Image.Image) -> bool:
+        if not cls.safety_checker_available():
             return False

         device = choose_torch_device()
-        features = self.feature_extractor([image], return_tensors="pt")
+        features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
-        self.safety_checker.to(device)
+        cls.safety_checker.to(device)
         x_image = np.array(image).astype(np.float32) / 255.0
         x_image = x_image[None].transpose(0, 3, 1, 2)
         with SilenceWarnings():
-            checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
+            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
         return has_nsfw_concept[0]
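Note on the `self` → `cls` renames: at runtime the name of a `@classmethod`'s first parameter is irrelevant, since Python passes the class object either way. Calling it `cls` follows convention and makes the intent explicit: assignments such as `cls.tried_load = True` write class-level state shared by every caller, which is what lets the lazy one-time load above work without any instance. A minimal sketch of the same pattern, using a hypothetical `LazyResource` class rather than the real `SafetyChecker`:

class LazyResource:
    loaded: bool = False  # class attribute: one flag per class, not per instance

    @classmethod
    def ensure_loaded(cls) -> None:
        # `cls` is the class itself, so this guard plus the assignment
        # below make the expensive load happen at most once per process
        if cls.loaded:
            return
        print("loading expensive resource ...")
        cls.loaded = True

LazyResource.ensure_loaded()  # performs the load
LazyResource.ensure_loaded()  # no-op: the class-level flag is already set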