update linear graphs to perform safety checking and watermarking

This commit is contained in:
Lincoln Stein 2023-07-23 23:32:08 -04:00 committed by psychedelicious
parent e32cd794f7
commit bd43751323
14 changed files with 243 additions and 112 deletions

View File

@ -20,7 +20,7 @@ from ...backend.model_management import BaseModelType, ModelType
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
InvocationConfig, InvocationContext)
from .image import ImageOutput, PILInvocationConfig
from .image_defs import ImageOutput, PILInvocationConfig
CONTROLNET_DEFAULT_MODELS = [
###########################################

View File

@ -7,10 +7,9 @@ from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
from pathlib import Path
from typing import Union
from invokeai.app.invocations.metadata import CoreMetadata
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from ..models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import (
BaseInvocation,
@ -18,53 +17,15 @@ from .baseinvocation import (
InvocationContext,
InvocationConfig,
)
from .image_defs import (
PILInvocationConfig,
ImageOutput,
MaskOutput,
)
from ..services.config import InvokeAIAppConfig
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend import SilenceWarnings
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}
class LoadImageInvocation(BaseInvocation):
"""Load an image and provide it as output."""
@ -656,13 +617,15 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
"""Add blur to NSFW-flagged images"""
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().nsfw_checker
# fmt: off
type: Literal["img_nsfw"] = "img_nsfw"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
active: bool = Field(default=True, description="Whether the NSFW checker is active")
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the NSFW checker is enabled")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
@ -676,55 +639,46 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name)
if not self.active:
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
config = context.services.configuration
logger = context.services.logger
device = choose_torch_device()
logger.info("Running NSFW checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
features = feature_extractor([image], return_tensors="pt")
features.to(device)
safety_checker.to(device)
if self.enabled:
logger.info("Running NSFW checker")
safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
x_image = numpy.array(image).astype(numpy.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
features = feature_extractor([image], return_tensors="pt")
features.to(device)
safety_checker.to(device)
logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
if has_nsfw_concept[0]:
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution,(0,0),caution)
x_image = numpy.array(image).astype(numpy.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
image_dto = context.services.images.create(
image=blurry_image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
else:
return ImageOutput(
image=ImageField(image_name=self.image.image_name),
width=image.width,
height=image.height,
)
logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
if has_nsfw_concept[0]:
blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
caution = self._get_caution_img()
blurry_image.paste(caution,(0,0),caution)
image = blurry_image
image_dto = context.services.images.create(
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(
image=ImageField(image_name=image_dto.image_name),
width=image_dto.width,
height=image_dto.height,
)
def _get_caution_img(self)->Image:
import invokeai.assets.web as web_assets
@ -733,12 +687,18 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
""" Add an invisible watermark to an image """
# to avoid circular import
DEFAULT_ENABLED = InvokeAIAppConfig.get_config().invisible_watermark
# fmt: off
type: Literal["img_watermark"] = "img_watermark"
# Inputs
image: Optional[ImageField] = Field(default=None, description="The image to check")
text: str = Field(default='InvokeAI', description="Watermark text")
enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the invisible watermark is enabled")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
# fmt: on
class Config(InvocationConfig):
@ -753,23 +713,28 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
import cv2
from imwatermark import WatermarkEncoder
logger = context.services.logger
image = context.services.images.get_pil_image(self.image.image_name)
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
wm = self.text
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', wm.encode('utf-8'))
bgr_encoded = encoder.encode(bgr, 'dwtDct')
new_image = Image.fromarray(
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
).convert("RGBA")
if self.enabled:
logger.info("Running invisible watermarker")
bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
wm = self.text
encoder = WatermarkEncoder()
encoder.set_watermark('bytes', wm.encode('utf-8'))
bgr_encoded = encoder.encode(bgr, 'dwtDct')
new_image = Image.fromarray(
cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
).convert("RGBA")
image = new_image
image_dto = context.services.images.create(
image=new_image,
image=image,
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
node_id=self.id,
session_id=context.graph_execution_state_id,
is_intermediate=self.is_intermediate,
metadata=self.metadata.dict() if self.metadata else None,
)
return ImageOutput(

View File

@ -0,0 +1,54 @@
# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
""" Common classes used by .image and .controlnet; avoids circular import issues """
from pydantic import BaseModel, Field
from typing import Literal
from ..models.image import ImageField
from .baseinvocation import (
BaseInvocationOutput,
InvocationConfig,
)
class PILInvocationConfig(BaseModel):
"""Helper class to provide all PIL invocations with additional config"""
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["PIL", "image"],
},
}
class ImageOutput(BaseInvocationOutput):
"""Base class for invocations that output an image"""
# fmt: off
type: Literal["image_output"] = "image_output"
image: ImageField = Field(default=None, description="The output image")
width: int = Field(description="The width of the image in pixels")
height: int = Field(description="The height of the image in pixels")
# fmt: on
class Config:
schema_extra = {"required": ["type", "image", "width", "height"]}
class MaskOutput(BaseInvocationOutput):
"""Base class for invocations that output a mask"""
# fmt: off
type: Literal["mask"] = "mask"
mask: ImageField = Field(default=None, description="The output mask")
width: int = Field(description="The width of the mask in pixels")
height: int = Field(description="The height of the mask in pixels")
# fmt: on
class Config:
schema_extra = {
"required": [
"type",
"mask",
]
}

View File

@ -501,7 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
vae: VaeField = Field(default=None, description="Vae submodel")
tiled: bool = Field(
default=False,
description="Decode latents by overlaping tiles(less memory consumption)")
description="Decode latents by overlapping tiles(less memory consumption)")
fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")

View File

@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
class LoRAMetadataField(BaseModel):
"""LoRA metadata for an image generated in InvokeAI."""

View File

@ -135,6 +135,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
print(f'DEBUG: board_id={board_id}, image_name={image_name}')
try:
self._lock.acquire()
self._cursor.execute(
@ -146,6 +147,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
(board_id, image_name, board_id),
)
self._conn.commit()
print('got here')
except sqlite3.Error as e:
self._conn.rollback()
raise e

View File

@ -365,6 +365,7 @@ setting environment variables INVOKEAI_<setting>.
internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
invisible_watermark : bool = Field(default=True, description="Enable/disable the invisible watermark", category='Features')
patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')

View File

@ -1,4 +1,5 @@
from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
from ..invocations.image import ImageNSFWBlurInvocation
from ..invocations.noise import NoiseInvocation
from ..invocations.compel import CompelInvocation
from ..invocations.params import ParamIntInvocation
@ -24,6 +25,7 @@ def create_text_to_image() -> LibraryGraph:
'5': CompelInvocation(id='5'),
'6': TextToLatentsInvocation(id='6'),
'7': LatentsToImageInvocation(id='7'),
'8': ImageNSFWBlurInvocation(id='8'),
},
edges=[
Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
@ -33,6 +35,7 @@ def create_text_to_image() -> LibraryGraph:
Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
]
),
exposed_inputs=[
@ -43,7 +46,7 @@ def create_text_to_image() -> LibraryGraph:
ExposedNodeInput(node_path='seed', field='a', alias='seed'),
],
exposed_outputs=[
ExposedNodeOutput(node_path='7', field='image', alias='image')
ExposedNodeOutput(node_path='8', field='image', alias='image')
])

View File

@ -216,16 +216,13 @@ class ImageService(ImageServiceABC):
metadata=metadata,
session_id=session_id,
)
if board_id is not None:
self._services.board_image_records.add_image_to_board(
board_id=board_id, image_name=image_name
)
self._services.image_files.save(
image_name=image_name, image=image, metadata=metadata, graph=graph
)
image_dto = self.get_dto(image_name)
return image_dto
@ -236,7 +233,7 @@ class ImageService(ImageServiceABC):
self._services.logger.error("Failed to save image file")
raise
except Exception as e:
self._services.logger.error("Problem saving image record and file")
self._services.logger.error(f"Problem saving image record and file: {str(e)}")
raise e
def update(

View File

@ -23,6 +23,8 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
/**
@ -104,9 +106,9 @@ export const buildCanvasImageToImageGraph = (
skipped_layers: clipSkip,
},
[LATENTS_TO_IMAGE]: {
is_intermediate: !shouldAutoSave,
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
},
[LATENTS_TO_LATENTS]: {
type: 'l2l',
@ -126,6 +128,16 @@ export const buildCanvasImageToImageGraph = (
// image_name: initialImage.image_name,
// },
},
[NSFW_CHECKER]: {
type: 'img_nsfw',
id: NSFW_CHECKER,
is_intermediate: true,
},
[WATERMARKER]: {
is_intermediate: !shouldAutoSave,
type: 'img_watermark',
id: WATERMARKER,
},
},
edges: [
{
@ -168,6 +180,26 @@ export const buildCanvasImageToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
},
{
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
@ -316,7 +348,7 @@ export const buildCanvasImageToImageGraph = (
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
node_id: WATERMARKER,
field: 'metadata',
},
});

View File

@ -16,6 +16,8 @@ import {
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
/**
@ -107,6 +109,16 @@ export const buildCanvasTextToImageGraph = (
[LATENTS_TO_IMAGE]: {
type: 'l2i',
id: LATENTS_TO_IMAGE,
is_intermediate: true,
},
[NSFW_CHECKER]: {
type: 'img_nsfw',
id: NSFW_CHECKER,
is_intermediate: true,
},
[WATERMARKER]: {
type: 'img_watermark',
id: WATERMARKER,
is_intermediate: !shouldAutoSave,
},
},
@ -181,6 +193,26 @@ export const buildCanvasTextToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
},
{
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
},
{
source: {
node_id: NOISE,
@ -221,7 +253,7 @@ export const buildCanvasTextToImageGraph = (
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
node_id: WATERMARKER,
field: 'metadata',
},
});

View File

@ -22,6 +22,8 @@ import {
NOISE,
POSITIVE_CONDITIONING,
RESIZE,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
/**
@ -185,6 +187,26 @@ export const buildLinearImageToImageGraph = (
field: 'latents',
},
},
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
},
{
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
},
{
source: {
node_id: IMAGE_TO_LATENTS,
@ -362,7 +384,7 @@ export const buildLinearImageToImageGraph = (
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
node_id: WATERMARKER,
field: 'metadata',
},
});

View File

@ -16,6 +16,8 @@ import {
POSITIVE_CONDITIONING,
TEXT_TO_IMAGE_GRAPH,
TEXT_TO_LATENTS,
NSFW_CHECKER,
WATERMARKER,
} from './constants';
export const buildLinearTextToImageGraph = (
@ -47,7 +49,7 @@ export const buildLinearTextToImageGraph = (
}
/**
* The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
v * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
* full graph here as a template. Then use the parameters from app state and set friendlier node
* ids.
*
@ -180,6 +182,26 @@ export const buildLinearTextToImageGraph = (
field: 'noise',
},
},
{
source: {
node_id: LATENTS_TO_IMAGE,
field: 'image',
},
destination: {
node_id: NSFW_CHECKER,
field: 'image',
},
},
{
source: {
node_id: NSFW_CHECKER,
field: 'image',
},
destination: {
node_id: WATERMARKER,
field: 'image',
},
},
],
};
@ -210,7 +232,7 @@ export const buildLinearTextToImageGraph = (
field: 'metadata',
},
destination: {
node_id: LATENTS_TO_IMAGE,
node_id: WATERMARKER,
field: 'metadata',
},
});

View File

@ -3,6 +3,8 @@ export const POSITIVE_CONDITIONING = 'positive_conditioning';
export const NEGATIVE_CONDITIONING = 'negative_conditioning';
export const TEXT_TO_LATENTS = 'text_to_latents';
export const LATENTS_TO_IMAGE = 'latents_to_image';
export const NSFW_CHECKER = 'nsfw_checker';
export const WATERMARKER = 'invisible_watermark';
export const NOISE = 'noise';
export const RANDOM_INT = 'rand_int';
export const RANGE_OF_SIZE = 'range_of_size';