refactored code; added watermark and nsfw facilities to app config route

This commit is contained in:
Lincoln Stein 2023-07-24 22:02:57 -04:00
parent 4194a0ed99
commit efa615a8fd
7 changed files with 139 additions and 138 deletions

View File

@ -1,9 +1,15 @@
import typing
from enum import Enum from enum import Enum
from fastapi import Body from fastapi import Body
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from pathlib import Path
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from invokeai.backend.image_util.patchmatch import PatchMatch from invokeai.backend.image_util.patchmatch import PatchMatch
from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.app.invocations.upscale import ESRGAN_MODELS
from invokeai.version import __version__ from invokeai.version import __version__
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -30,6 +36,10 @@ class AppConfig(BaseModel):
"""App Config Response""" """App Config Response"""
infill_methods: list[str] = Field(description="List of available infill methods") infill_methods: list[str] = Field(description="List of available infill methods")
upscaling_methods: list[str] = Field(description="List of upscaling methods")
upscaling_models: list[str] = Field(description="List of postprocessing methods")
nsfw_methods: list[str] = Field(description="List of NSFW checking methods")
watermarking_methods: list[str] = Field(description="List of invisible watermark methods")
@app_router.get( @app_router.get(
@ -46,7 +56,27 @@ async def get_config() -> AppConfig:
infill_methods = ['tile'] infill_methods = ['tile']
if PatchMatch.patchmatch_available(): if PatchMatch.patchmatch_available():
infill_methods.append('patchmatch') infill_methods.append('patchmatch')
return AppConfig(infill_methods=infill_methods)
upscaling_methods = ['esrgan']
upscaling_models = []
for model in typing.get_args(ESRGAN_MODELS):
upscaling_models.append(str(Path(model).stem))
nsfw_methods = []
if SafetyChecker.safety_checker_available():
nsfw_methods.append('nsfw_checker')
watermarking_methods = []
if InvisibleWatermark.invisible_watermark_available():
watermarking_methods.append('invisible_watermark')
return AppConfig(
infill_methods=infill_methods,
upscaling_methods=upscaling_methods,
upscaling_models=upscaling_models,
nsfw_methods=nsfw_methods,
watermarking_methods=watermarking_methods,
)
@app_router.get( @app_router.get(
"/logging", "/logging",

View File

@ -8,8 +8,6 @@ from pydantic import Field
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from ..models.image import ( from ..models.image import (
ImageCategory, ImageField, ResourceOrigin, ImageCategory, ImageField, ResourceOrigin,
PILInvocationConfig, ImageOutput, MaskOutput, PILInvocationConfig, ImageOutput, MaskOutput,
@ -19,9 +17,8 @@ from .baseinvocation import (
InvocationContext, InvocationContext,
InvocationConfig, InvocationConfig,
) )
from ..services.config import InvokeAIAppConfig from invokeai.backend.image_util.safety_checker import SafetyChecker
from invokeai.backend.util.devices import choose_torch_device from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
from invokeai.backend import SilenceWarnings
class LoadImageInvocation(BaseInvocation): class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output.""" """Load an image and provide it as output."""
@ -614,14 +611,12 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig): class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Add blur to NSFW-flagged images""" """Add blur to NSFW-flagged images"""
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().nsfw_checker
# fmt: off # fmt: off
type: Literal["img_nsfw"] = "img_nsfw" type: Literal["img_nsfw"] = "img_nsfw"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check") image: Optional[ImageField] = Field(default=None, description="The image to check")
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the NSFW checker is enabled")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on # fmt: on
@ -636,26 +631,10 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
config = context.services.configuration
logger = context.services.logger logger = context.services.logger
device = choose_torch_device()
if self.enabled:
logger.debug("Running NSFW checker") logger.debug("Running NSFW checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker') if SafetyChecker.has_nsfw_concept(image):
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker') logger.info("A potentially NSFW image has been detected. Image will be blurred.")
features = feature_extractor([image], return_tensors="pt")
features.to(device)
safety_checker.to(device)
x_image = numpy.array(image).astype(numpy.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
if has_nsfw_concept[0]:
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32)) blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img() caution = self._get_caution_img()
blurry_image.paste(caution,(0,0),caution) blurry_image.paste(caution,(0,0),caution)
@ -685,16 +664,12 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig): class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
""" Add an invisible watermark to an image """ """ Add an invisible watermark to an image """
# to avoid circular import
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().invisible_watermark
# fmt: off # fmt: off
type: Literal["img_watermark"] = "img_watermark" type: Literal["img_watermark"] = "img_watermark"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check") image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text") text: str = Field(default='InvokeAI', description="Watermark text")
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the invisible watermark is enabled")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image") metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on # fmt: on
@ -707,25 +682,10 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
} }
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
import cv2
from imwatermark import WatermarkEncoder
logger = context.services.logger
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name)
if self.enabled: new_image = InvisibleWatermark.add_watermark(image, self.text)
logger.debug("Running invisible watermarker")
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
wm = self.text
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', wm.encode('utf-8'))
bgr_encoded = encoder.encode(bgr, 'dwtDct')
new_image = Image.fromarray(
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
).convert("RGBA")
image = new_image
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=image, image=new_image,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,

View File

@ -12,5 +12,4 @@ from .model_management import (
ModelManager, ModelCache, BaseModelType, ModelManager, ModelCache, BaseModelType,
ModelType, SubModelType, ModelInfo ModelType, SubModelType, ModelInfo
) )
from .safety_checker import SafetyChecker
from .model_management.models import SilenceWarnings from .model_management.models import SilenceWarnings

View File

@ -28,7 +28,6 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from ..image_util import configure_model_padding from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
from ..stable_diffusion.schedulers import SCHEDULER_MAP from ..stable_diffusion.schedulers import SCHEDULER_MAP
@ -52,7 +51,6 @@ class InvokeAIGeneratorBasicParams:
v_symmetry_time_pct: Optional[float]=None v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0 variation_amount: float = 0.0
with_variations: list=field(default_factory=list) with_variations: list=field(default_factory=list)
safety_checker: Optional[SafetyChecker]=None
@dataclass @dataclass
class InvokeAIGeneratorOutput: class InvokeAIGeneratorOutput:
@ -240,7 +238,6 @@ class Generator:
self.seed = None self.seed = None
self.latent_channels = model.unet.config.in_channels self.latent_channels = model.unet.config.in_channels
self.downsampling_factor = downsampling # BUG: should come from model or config self.downsampling_factor = downsampling # BUG: should come from model or config
self.safety_checker = None
self.perlin = 0.0 self.perlin = 0.0
self.threshold = 0 self.threshold = 0
self.variation_amount = 0 self.variation_amount = 0
@ -277,12 +274,10 @@ class Generator:
perlin=0.0, perlin=0.0,
h_symmetry_time_pct=None, h_symmetry_time_pct=None,
v_symmetry_time_pct=None, v_symmetry_time_pct=None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False, free_gpu_mem: bool = False,
**kwargs, **kwargs,
): ):
scope = nullcontext scope = nullcontext
self.safety_checker = safety_checker
self.free_gpu_mem = free_gpu_mem self.free_gpu_mem = free_gpu_mem
attention_maps_images = [] attention_maps_images = []
attention_maps_callback = lambda saver: attention_maps_images.append( attention_maps_callback = lambda saver: attention_maps_images.append(
@ -329,9 +324,6 @@ class Generator:
# Pass on the seed in case a layer beneath us needs to generate noise on its own. # Pass on the seed in case a layer beneath us needs to generate noise on its own.
image = make_image(x_T, seed) image = make_image(x_T, seed)
if self.safety_checker is not None:
image = self.safety_checker.check(image)
results.append([image, seed, attention_maps_images]) results.append([image, seed, attention_maps_images])
if image_callback is not None: if image_callback is not None:

View File

@ -0,0 +1,34 @@
"""
This module defines a singleton object, "invisible_watermark" that
wraps the invisible watermark model. It respects the global "invisible_watermark"
configuration variable, that allows the watermarking to be supressed.
"""
import numpy as np
import cv2
from PIL import Image
from imwatermark import WatermarkEncoder
from invokeai.app.services.config import InvokeAIAppConfig
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
class InvisibleWatermark:
"""
Wrapper around InvisibleWatermark module.
"""
@classmethod
def invisible_watermark_available(self) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text:str) -> Image:
if not self.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', watermark_text.encode('utf-8'))
bgr_encoded = encoder.encode(bgr, 'dwtDct')
return Image.fromarray(
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
).convert("RGBA")

View File

@ -0,0 +1,63 @@
"""
This module defines a singleton object, "safety_checker" that
wraps the safety_checker model. It respects the global "nsfw_checker"
configuration variable, that allows the checker to be supressed.
"""
import numpy as np
from PIL import Image
from invokeai.backend import SilenceWarnings
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
import invokeai.backend.util.logging as logger
config = InvokeAIAppConfig.get_config()
CHECKER_PATH = 'core/convert/stable-diffusion-safety-checker'
class SafetyChecker:
"""
Wrapper around SafetyChecker model.
"""
safety_checker = None
feature_extractor = None
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
return
if config.nsfw_checker:
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
config.models_path / CHECKER_PATH
)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(
config.models_path / CHECKER_PATH)
logger.info('NSFW checker initialized')
except Exception as e:
logger.warning(f'Could not load NSFW checker: {str(e)}')
else:
logger.info('NSFW checker loading disabled')
self.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

View File

@ -1,77 +0,0 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from .util import CPU_DEVICE
config = InvokeAIAppConfig.get_config()
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
self.device = device
try:
safety_model_id = config.models_path / 'core/convert/stable-diffusion-safety-checker'
feature_extractor_id = config.models_path / 'core/convert/stable-diffusion-safety-checker-extractor'
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_id)
except Exception:
logger.error(
"An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
self.safety_checker.to(self.device)
features = self.safety_feature_extractor([image], return_tensors="pt")
features.to(self.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
self.safety_checker.to(CPU_DEVICE) # offload
if has_nsfw_concept[0]:
logger.warning(
"An image with potential non-safe content has been detected. A blurred image will be returned."
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
if caution := self.caution_img:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry