fix(backend): fix nsfw/watermarker util types

This commit is contained in:
psychedelicious 2023-10-17 16:50:52 +11:00
parent 284a257c25
commit 252c9a5f5a
2 changed files with 16 additions and 16 deletions

View File

@ -20,12 +20,12 @@ class InvisibleWatermark:
""" """
@classmethod @classmethod
def invisible_watermark_available(self) -> bool: def invisible_watermark_available(cls) -> bool:
return config.invisible_watermark return config.invisible_watermark
@classmethod @classmethod
def add_watermark(self, image: Image, watermark_text: str) -> Image: def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
if not self.invisible_watermark_available(): if not cls.invisible_watermark_available():
return image return image
logger.debug(f'Applying invisible watermark "{watermark_text}"') logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

View File

@ -26,8 +26,8 @@ class SafetyChecker:
tried_load: bool = False tried_load: bool = False
@classmethod @classmethod
def _load_safety_checker(self): def _load_safety_checker(cls):
if self.tried_load: if cls.tried_load:
return return
if config.nsfw_checker: if config.nsfw_checker:
@ -35,31 +35,31 @@ class SafetyChecker:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH) cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH) cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized") logger.info("NSFW checker initialized")
except Exception as e: except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}") logger.warning(f"Could not load NSFW checker: {str(e)}")
else: else:
logger.info("NSFW checker loading disabled") logger.info("NSFW checker loading disabled")
self.tried_load = True cls.tried_load = True
@classmethod @classmethod
def safety_checker_available(self) -> bool: def safety_checker_available(cls) -> bool:
self._load_safety_checker() cls._load_safety_checker()
return self.safety_checker is not None return cls.safety_checker is not None
@classmethod @classmethod
def has_nsfw_concept(self, image: Image) -> bool: def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not self.safety_checker_available(): if not cls.safety_checker_available():
return False return False
device = choose_torch_device() device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt") features = cls.feature_extractor([image], return_tensors="pt")
features.to(device) features.to(device)
self.safety_checker.to(device) cls.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0 x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2) x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings(): with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values) checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0] return has_nsfw_concept[0]