diff --git a/invokeai/backend/image_util/invisible_watermark.py b/invokeai/backend/image_util/invisible_watermark.py
index 3e8604f9c3..37b3ca918c 100644
--- a/invokeai/backend/image_util/invisible_watermark.py
+++ b/invokeai/backend/image_util/invisible_watermark.py
@@ -20,12 +20,12 @@ class InvisibleWatermark:
     """

     @classmethod
-    def invisible_watermark_available(self) -> bool:
+    def invisible_watermark_available(cls) -> bool:
         return config.invisible_watermark

     @classmethod
-    def add_watermark(self, image: Image, watermark_text: str) -> Image:
-        if not self.invisible_watermark_available():
+    def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
+        if not cls.invisible_watermark_available():
             return image
         logger.debug(f'Applying invisible watermark "{watermark_text}"')
         bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py
index fd1f05f10e..b9649925e1 100644
--- a/invokeai/backend/image_util/safety_checker.py
+++ b/invokeai/backend/image_util/safety_checker.py
@@ -26,8 +26,8 @@ class SafetyChecker:
     tried_load: bool = False

     @classmethod
-    def _load_safety_checker(self):
-        if self.tried_load:
+    def _load_safety_checker(cls):
+        if cls.tried_load:
             return

         if config.nsfw_checker:
@@ -35,31 +35,31 @@ class SafetyChecker:
                 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                 from transformers import AutoFeatureExtractor

-                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
-                self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
+                cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
                 logger.info("NSFW checker initialized")
             except Exception as e:
                 logger.warning(f"Could not load NSFW checker: {str(e)}")
         else:
             logger.info("NSFW checker loading disabled")

-        self.tried_load = True
+        cls.tried_load = True

     @classmethod
-    def safety_checker_available(self) -> bool:
-        self._load_safety_checker()
-        return self.safety_checker is not None
+    def safety_checker_available(cls) -> bool:
+        cls._load_safety_checker()
+        return cls.safety_checker is not None

     @classmethod
-    def has_nsfw_concept(self, image: Image) -> bool:
-        if not self.safety_checker_available():
+    def has_nsfw_concept(cls, image: Image.Image) -> bool:
+        if not cls.safety_checker_available():
             return False
         device = choose_torch_device()
-        features = self.feature_extractor([image], return_tensors="pt")
+        features = cls.feature_extractor([image], return_tensors="pt")
         features.to(device)
-        self.safety_checker.to(device)
+        cls.safety_checker.to(device)
         x_image = np.array(image).astype(np.float32) / 255.0
         x_image = x_image[None].transpose(0, 3, 1, 2)
         with SilenceWarnings():
-            checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
+            checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
         return has_nsfw_concept[0]