update linear graphs to perform safety checking and watermarking

Lincoln Stein 2023-07-23 23:32:08 -04:00 committed by psychedelicious
parent e32cd794f7
commit bd43751323
14 changed files with 243 additions and 112 deletions

View File

@@ -20,7 +20,7 @@ from ...backend.model_management import BaseModelType, ModelType
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                              InvocationConfig, InvocationContext)
-from .image import ImageOutput, PILInvocationConfig
+from .image_defs import ImageOutput, PILInvocationConfig
 CONTROLNET_DEFAULT_MODELS = [
 ###########################################

View File

@@ -7,10 +7,9 @@ from PIL import Image, ImageFilter, ImageOps, ImageChops
 from pydantic import BaseModel, Field
 from pathlib import Path
 from typing import Union
+from invokeai.app.invocations.metadata import CoreMetadata
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from transformers import AutoFeatureExtractor
 from ..models.image import ImageCategory, ImageField, ResourceOrigin
 from .baseinvocation import (
     BaseInvocation,
@@ -18,53 +17,15 @@ from .baseinvocation import (
     InvocationContext,
     InvocationConfig,
 )
+from .image_defs import (
+    PILInvocationConfig,
+    ImageOutput,
+    MaskOutput,
+)
+from ..services.config import InvokeAIAppConfig
 from invokeai.backend.util.devices import choose_torch_device
 from invokeai.backend import SilenceWarnings
-
-class PILInvocationConfig(BaseModel):
-    """Helper class to provide all PIL invocations with additional config"""
-
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["PIL", "image"],
-            },
-        }
-
-class ImageOutput(BaseInvocationOutput):
-    """Base class for invocations that output an image"""
-
-    # fmt: off
-    type: Literal["image_output"] = "image_output"
-    image: ImageField = Field(default=None, description="The output image")
-    width: int = Field(description="The width of the image in pixels")
-    height: int = Field(description="The height of the image in pixels")
-    # fmt: on
-
-    class Config:
-        schema_extra = {"required": ["type", "image", "width", "height"]}
-
-class MaskOutput(BaseInvocationOutput):
-    """Base class for invocations that output a mask"""
-
-    # fmt: off
-    type: Literal["mask"] = "mask"
-    mask: ImageField = Field(default=None, description="The output mask")
-    width: int = Field(description="The width of the mask in pixels")
-    height: int = Field(description="The height of the mask in pixels")
-    # fmt: on
-
-    class Config:
-        schema_extra = {
-            "required": [
-                "type",
-                "mask",
-            ]
-        }
 
 class LoadImageInvocation(BaseInvocation):
     """Load an image and provide it as output."""
@@ -656,13 +617,15 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
 class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
     """Add blur to NSFW-flagged images"""
 
+    DEFAULT_ENABLED = InvokeAIAppConfig.get_config().nsfw_checker
+
     # fmt: off
     type: Literal["img_nsfw"] = "img_nsfw"
 
     # Inputs
     image: Optional[ImageField] = Field(default=None, description="The image to check")
-    active: bool = Field(default=True, description="Whether the NSFW checker is active")
+    enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the NSFW checker is enabled")
+    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
     # fmt: on
 
     class Config(InvocationConfig):
@@ -676,55 +639,46 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
     def invoke(self, context: InvocationContext) -> ImageOutput:
         image = context.services.images.get_pil_image(self.image.image_name)
 
-        if not self.active:
-            return ImageOutput(
-                image=ImageField(image_name=self.image.image_name),
-                width=image.width,
-                height=image.height,
-            )
-
         config = context.services.configuration
         logger = context.services.logger
         device = choose_torch_device()
 
-        logger.info("Running NSFW checker")
-        safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
-        feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
-
-        features = feature_extractor([image], return_tensors="pt")
-        features.to(device)
-        safety_checker.to(device)
-
-        x_image = numpy.array(image).astype(numpy.float32) / 255.0
-        x_image = x_image[None].transpose(0, 3, 1, 2)
-        with SilenceWarnings():
-            checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
-
-        logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
-        if has_nsfw_concept[0]:
-            blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
-            caution = self._get_caution_img()
-            blurry_image.paste(caution,(0,0),caution)
-            image_dto = context.services.images.create(
-                image=blurry_image,
-                image_origin=ResourceOrigin.INTERNAL,
-                image_category=ImageCategory.GENERAL,
-                node_id=self.id,
-                session_id=context.graph_execution_state_id,
-                is_intermediate=self.is_intermediate,
-            )
-            return ImageOutput(
-                image=ImageField(image_name=image_dto.image_name),
-                width=image_dto.width,
-                height=image_dto.height,
-            )
-        else:
-            return ImageOutput(
-                image=ImageField(image_name=self.image.image_name),
-                width=image.width,
-                height=image.height,
-            )
+        if self.enabled:
+            logger.info("Running NSFW checker")
+            safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
+            feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / 'core/convert/stable-diffusion-safety-checker')
+
+            features = feature_extractor([image], return_tensors="pt")
+            features.to(device)
+            safety_checker.to(device)
+
+            x_image = numpy.array(image).astype(numpy.float32) / 255.0
+            x_image = x_image[None].transpose(0, 3, 1, 2)
+            with SilenceWarnings():
+                checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)
+
+            logger.info(f"NSFW scan result: {has_nsfw_concept[0]}")
+            if has_nsfw_concept[0]:
+                blurry_image = image.filter(filter=ImageFilter.GaussianBlur(radius=32))
+                caution = self._get_caution_img()
+                blurry_image.paste(caution,(0,0),caution)
+                image = blurry_image
+
+        image_dto = context.services.images.create(
+            image=image,
+            image_origin=ResourceOrigin.INTERNAL,
+            image_category=ImageCategory.GENERAL,
+            node_id=self.id,
+            session_id=context.graph_execution_state_id,
+            is_intermediate=self.is_intermediate,
+            metadata=self.metadata.dict() if self.metadata else None,
+        )
+
+        return ImageOutput(
+            image=ImageField(image_name=image_dto.image_name),
+            width=image_dto.width,
+            height=image_dto.height,
+        )
 
     def _get_caution_img(self)->Image:
         import invokeai.assets.web as web_assets
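
For reference, the checker flow above can be exercised outside the invocation graph. A minimal standalone sketch, assuming the converted safety-checker weights are already installed (CHECKER_PATH stands in for config.models_path / 'core/convert/stable-diffusion-safety-checker'); device placement is omitted, so this runs on CPU:

# Standalone sketch of the NSFW-check-and-blur flow in the hunk above.
import numpy
from PIL import Image, ImageFilter
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor

# Stand-in for config.models_path / 'core/convert/stable-diffusion-safety-checker'
CHECKER_PATH = "models/core/convert/stable-diffusion-safety-checker"

def blur_if_nsfw(image: Image.Image) -> Image.Image:
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(CHECKER_PATH)
    feature_extractor = AutoFeatureExtractor.from_pretrained(CHECKER_PATH)

    features = feature_extractor([image], return_tensors="pt")
    # The checker expects a float32 NCHW batch scaled to [0, 1].
    x_image = numpy.array(image.convert("RGB")).astype(numpy.float32) / 255.0
    x_image = x_image[None].transpose(0, 3, 1, 2)
    _, has_nsfw_concept = safety_checker(images=x_image, clip_input=features.pixel_values)

    if has_nsfw_concept[0]:
        # Same remediation as the invocation: a heavy Gaussian blur
        # (the caution overlay is skipped in this sketch).
        return image.filter(filter=ImageFilter.GaussianBlur(radius=32))
    return image
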
@@ -733,12 +687,18 @@ class ImageNSFWBlurInvocation(BaseInvocation, PILInvocationConfig):
 class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
     """ Add an invisible watermark to an image """
 
+    # to avoid circular import
+    DEFAULT_ENABLED = InvokeAIAppConfig.get_config().invisible_watermark
+
     # fmt: off
     type: Literal["img_watermark"] = "img_watermark"
 
     # Inputs
     image: Optional[ImageField] = Field(default=None, description="The image to check")
     text: str = Field(default='InvokeAI', description="Watermark text")
+    enabled: bool = Field(default=DEFAULT_ENABLED, description="Whether the invisible watermark is enabled")
+    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
     # fmt: on
 
     class Config(InvocationConfig):
@@ -753,23 +713,28 @@ class ImageWatermarkInvocation(BaseInvocation, PILInvocationConfig):
         import cv2
         from imwatermark import WatermarkEncoder
 
+        logger = context.services.logger
         image = context.services.images.get_pil_image(self.image.image_name)
 
-        bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
-        wm = self.text
-        encoder = WatermarkEncoder()
-        encoder.set_watermark('bytes', wm.encode('utf-8'))
-        bgr_encoded = encoder.encode(bgr, 'dwtDct')
-        new_image = Image.fromarray(
-            cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
-        ).convert("RGBA")
+        if self.enabled:
+            logger.info("Running invisible watermarker")
+            bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
+            wm = self.text
+            encoder = WatermarkEncoder()
+            encoder.set_watermark('bytes', wm.encode('utf-8'))
+            bgr_encoded = encoder.encode(bgr, 'dwtDct')
+            new_image = Image.fromarray(
+                cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)
+            ).convert("RGBA")
+            image = new_image
 
         image_dto = context.services.images.create(
-            image=new_image,
+            image=image,
             image_origin=ResourceOrigin.INTERNAL,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
             session_id=context.graph_execution_state_id,
             is_intermediate=self.is_intermediate,
+            metadata=self.metadata.dict() if self.metadata else None,
         )
 
         return ImageOutput(
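
The watermark is only useful if it can be read back out. A round-trip sketch using the same imwatermark primitives as the invocation; note that decoding requires the payload length in bits, which the image itself does not carry ('InvokeAI' is 8 bytes, hence 64 bits):

# Encode/decode round trip for the invisible watermark used above.
import cv2
import numpy
from PIL import Image
from imwatermark import WatermarkEncoder, WatermarkDecoder

def embed(image: Image.Image, text: str = 'InvokeAI') -> Image.Image:
    # Mirrors the invocation: RGB -> BGR, embed bytes with dwtDct, back to RGBA.
    bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    encoder = WatermarkEncoder()
    encoder.set_watermark('bytes', text.encode('utf-8'))
    bgr_encoded = encoder.encode(bgr, 'dwtDct')
    return Image.fromarray(cv2.cvtColor(bgr_encoded, cv2.COLOR_BGR2RGB)).convert("RGBA")

def extract(image: Image.Image, n_bytes: int = 8) -> str:
    # 'bytes' watermarks are read back by bit length: 8 bits per byte.
    bgr = cv2.cvtColor(numpy.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
    decoder = WatermarkDecoder('bytes', n_bytes * 8)
    return decoder.decode(bgr, 'dwtDct').decode('utf-8')

# marked = embed(Image.open("out.png"))
# assert extract(marked) == 'InvokeAI'
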

View File

@@ -0,0 +1,54 @@
+# Copyright 2023 Lincoln D. Stein and the InvokeAI Team
+""" Common classes used by .image and .controlnet; avoids circular import issues """
+
+from pydantic import BaseModel, Field
+from typing import Literal
+from ..models.image import ImageField
+from .baseinvocation import (
+    BaseInvocationOutput,
+    InvocationConfig,
+)
+
+class PILInvocationConfig(BaseModel):
+    """Helper class to provide all PIL invocations with additional config"""
+
+    class Config(InvocationConfig):
+        schema_extra = {
+            "ui": {
+                "tags": ["PIL", "image"],
+            },
+        }
+
+class ImageOutput(BaseInvocationOutput):
+    """Base class for invocations that output an image"""
+
+    # fmt: off
+    type: Literal["image_output"] = "image_output"
+    image: ImageField = Field(default=None, description="The output image")
+    width: int = Field(description="The width of the image in pixels")
+    height: int = Field(description="The height of the image in pixels")
+    # fmt: on
+
+    class Config:
+        schema_extra = {"required": ["type", "image", "width", "height"]}
+
+class MaskOutput(BaseInvocationOutput):
+    """Base class for invocations that output a mask"""
+
+    # fmt: off
+    type: Literal["mask"] = "mask"
+    mask: ImageField = Field(default=None, description="The output mask")
+    width: int = Field(description="The width of the mask in pixels")
+    height: int = Field(description="The height of the mask in pixels")
+    # fmt: on
+
+    class Config:
+        schema_extra = {
+            "required": [
+                "type",
+                "mask",
+            ]
+        }
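
This module is the point of the import reshuffle: .image and .controlnet_image_processors can both depend on it without depending on each other. A sketch of a consumer, where ExampleInvocation is hypothetical (the real nodes live in image.py) but the declaration pattern mirrors the diff:

# Hypothetical consumer of the new shared module.
from typing import Literal, Optional

from pydantic import Field

from invokeai.app.invocations.baseinvocation import BaseInvocation, InvocationContext
from invokeai.app.invocations.image_defs import ImageOutput, PILInvocationConfig
from invokeai.app.models.image import ImageField

class ExampleInvocation(BaseInvocation, PILInvocationConfig):
    """Hypothetical node: passes its input image through unchanged."""

    type: Literal["example_passthrough"] = "example_passthrough"
    image: Optional[ImageField] = Field(default=None, description="The image to pass through")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        image = context.services.images.get_pil_image(self.image.image_name)
        return ImageOutput(
            image=ImageField(image_name=self.image.image_name),
            width=image.width,
            height=image.height,
        )
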

View File

@@ -501,7 +501,7 @@ class LatentsToImageInvocation(BaseInvocation):
     vae: VaeField = Field(default=None, description="Vae submodel")
     tiled: bool = Field(
         default=False,
-        description="Decode latents by overlaping tiles(less memory consumption)")
+        description="Decode latents by overlapping tiles(less memory consumption)")
     fp32: bool = Field(DEFAULT_PRECISION=='float32', description="Decode in full precision")
     metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")

View File

@@ -11,7 +11,6 @@ from invokeai.app.invocations.baseinvocation import (
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 
 class LoRAMetadataField(BaseModel):
     """LoRA metadata for an image generated in InvokeAI."""

View File

@@ -135,6 +135,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
         board_id: str,
         image_name: str,
     ) -> None:
+        print(f'DEBUG: board_id={board_id}, image_name={image_name}')
         try:
             self._lock.acquire()
             self._cursor.execute(
@@ -146,6 +147,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
                 (board_id, image_name, board_id),
             )
             self._conn.commit()
+            print('got here')
         except sqlite3.Error as e:
             self._conn.rollback()
             raise e

View File

@@ -365,6 +365,7 @@ setting environment variables INVOKEAI_<setting>.
     internet_available : bool = Field(default=True, description="If true, attempt to download models on the fly; otherwise only use local models", category='Features')
     log_tokenization : bool = Field(default=False, description="Enable logging of parsed prompt tokens.", category='Features')
     nsfw_checker : bool = Field(default=True, description="Enable/disable the NSFW checker", category='Features')
+    invisible_watermark : bool = Field(default=True, description="Enable/disable the invisible watermark", category='Features')
     patchmatch : bool = Field(default=True, description="Enable/disable patchmatch inpaint code", category='Features')
     restore : bool = Field(default=True, description="Enable/disable face restoration code (DEPRECATED)", category='DEPRECATED')
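
The new flag rides on the existing settings plumbing: per the docstring in the hunk header, any field can be overridden with an INVOKEAI_<setting> environment variable. A small sketch of how the two flags surface at runtime (the override is illustrative and must be set before the config singleton is first built):

# Reading the safety/watermark flags from the app config.
import os
os.environ["INVOKEAI_invisible_watermark"] = "false"  # illustrative override

from invokeai.app.services.config import InvokeAIAppConfig

conf = InvokeAIAppConfig.get_config()
print(conf.nsfw_checker)         # True unless overridden
print(conf.invisible_watermark)  # False here, because of the env override
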

View File

@@ -1,4 +1,5 @@
 from ..invocations.latent import LatentsToImageInvocation, TextToLatentsInvocation
+from ..invocations.image import ImageNSFWBlurInvocation
 from ..invocations.noise import NoiseInvocation
 from ..invocations.compel import CompelInvocation
 from ..invocations.params import ParamIntInvocation
@@ -24,6 +25,7 @@ def create_text_to_image() -> LibraryGraph:
             '5': CompelInvocation(id='5'),
             '6': TextToLatentsInvocation(id='6'),
             '7': LatentsToImageInvocation(id='7'),
+            '8': ImageNSFWBlurInvocation(id='8'),
         },
         edges=[
             Edge(source=EdgeConnection(node_id='width', field='a'), destination=EdgeConnection(node_id='3', field='width')),
@@ -33,6 +35,7 @@ def create_text_to_image() -> LibraryGraph:
             Edge(source=EdgeConnection(node_id='6', field='latents'), destination=EdgeConnection(node_id='7', field='latents')),
             Edge(source=EdgeConnection(node_id='4', field='conditioning'), destination=EdgeConnection(node_id='6', field='positive_conditioning')),
             Edge(source=EdgeConnection(node_id='5', field='conditioning'), destination=EdgeConnection(node_id='6', field='negative_conditioning')),
+            Edge(source=EdgeConnection(node_id='7', field='image'), destination=EdgeConnection(node_id='8', field='image')),
         ]
     ),
     exposed_inputs=[
@@ -43,7 +46,7 @@ def create_text_to_image() -> LibraryGraph:
         ExposedNodeInput(node_path='seed', field='a', alias='seed'),
     ],
     exposed_outputs=[
-        ExposedNodeOutput(node_path='7', field='image', alias='image')
+        ExposedNodeOutput(node_path='8', field='image', alias='image')
     ])
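
The wiring is mechanical: each safety node consumes the previous node's image, and the exposed output alias moves to the end of the chain. A hypothetical sketch of extending create_text_to_image() one step further with the watermarker as node '9' (this commit stops at the NSFW blur for the library graph; the import path for the graph primitives is an assumption):

# Hypothetical: chain the watermarker after node '8' using the same pattern.
from invokeai.app.invocations.image import ImageWatermarkInvocation
# Assumed location of the graph primitives used by default_graphs.py:
from invokeai.app.services.graph import Edge, EdgeConnection, ExposedNodeOutput

def append_watermarker(library_graph) -> None:
    """Add node '9' (img_watermark) after node '8' (img_nsfw) and re-point the output."""
    library_graph.graph.nodes['9'] = ImageWatermarkInvocation(id='9')
    library_graph.graph.edges.append(
        Edge(
            source=EdgeConnection(node_id='8', field='image'),
            destination=EdgeConnection(node_id='9', field='image'),
        )
    )
    library_graph.exposed_outputs = [
        ExposedNodeOutput(node_path='9', field='image', alias='image')
    ]
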

View File

@@ -216,16 +216,13 @@ class ImageService(ImageServiceABC):
                 metadata=metadata,
                 session_id=session_id,
             )
-
             if board_id is not None:
                 self._services.board_image_records.add_image_to_board(
                     board_id=board_id, image_name=image_name
                 )
-
             self._services.image_files.save(
                 image_name=image_name, image=image, metadata=metadata, graph=graph
             )
-
             image_dto = self.get_dto(image_name)
 
             return image_dto
@@ -236,7 +233,7 @@ class ImageService(ImageServiceABC):
             self._services.logger.error("Failed to save image file")
             raise
         except Exception as e:
-            self._services.logger.error("Problem saving image record and file")
+            self._services.logger.error(f"Problem saving image record and file: {str(e)}")
             raise e
 
     def update(

View File

@@ -23,6 +23,8 @@ import {
   NOISE,
   POSITIVE_CONDITIONING,
   RESIZE,
+  NSFW_CHECKER,
+  WATERMARKER,
 } from './constants';
 
 /**
@@ -104,9 +106,9 @@
       skipped_layers: clipSkip,
     },
     [LATENTS_TO_IMAGE]: {
-      is_intermediate: !shouldAutoSave,
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
+      is_intermediate: true,
     },
     [LATENTS_TO_LATENTS]: {
       type: 'l2l',
@@ -126,6 +128,16 @@
       //   image_name: initialImage.image_name,
       // },
     },
+    [NSFW_CHECKER]: {
+      type: 'img_nsfw',
+      id: NSFW_CHECKER,
+      is_intermediate: true,
+    },
+    [WATERMARKER]: {
+      is_intermediate: !shouldAutoSave,
+      type: 'img_watermark',
+      id: WATERMARKER,
+    },
   },
   edges: [
     {
@@ -168,6 +180,26 @@
         field: 'latents',
       },
     },
+    {
+      source: {
+        node_id: LATENTS_TO_IMAGE,
+        field: 'image',
+      },
+      destination: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+    },
+    {
+      source: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+      destination: {
+        node_id: WATERMARKER,
+        field: 'image',
+      },
+    },
     {
       source: {
         node_id: IMAGE_TO_LATENTS,
@@ -316,7 +348,7 @@
       field: 'metadata',
     },
     destination: {
-      node_id: LATENTS_TO_IMAGE,
+      node_id: WATERMARKER,
      field: 'metadata',
    },
  });

View File

@@ -16,6 +16,8 @@ import {
   POSITIVE_CONDITIONING,
   TEXT_TO_IMAGE_GRAPH,
   TEXT_TO_LATENTS,
+  NSFW_CHECKER,
+  WATERMARKER,
 } from './constants';
 
 /**
@@ -107,6 +109,16 @@
     [LATENTS_TO_IMAGE]: {
       type: 'l2i',
       id: LATENTS_TO_IMAGE,
+      is_intermediate: true,
+    },
+    [NSFW_CHECKER]: {
+      type: 'img_nsfw',
+      id: NSFW_CHECKER,
+      is_intermediate: true,
+    },
+    [WATERMARKER]: {
+      type: 'img_watermark',
+      id: WATERMARKER,
       is_intermediate: !shouldAutoSave,
     },
   },
@@ -181,6 +193,26 @@
         field: 'latents',
       },
     },
+    {
+      source: {
+        node_id: LATENTS_TO_IMAGE,
+        field: 'image',
+      },
+      destination: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+    },
+    {
+      source: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+      destination: {
+        node_id: WATERMARKER,
+        field: 'image',
+      },
+    },
     {
       source: {
         node_id: NOISE,
@@ -221,7 +253,7 @@
       field: 'metadata',
     },
     destination: {
-      node_id: LATENTS_TO_IMAGE,
+      node_id: WATERMARKER,
      field: 'metadata',
    },
  });

View File

@@ -22,6 +22,8 @@ import {
   NOISE,
   POSITIVE_CONDITIONING,
   RESIZE,
+  NSFW_CHECKER,
+  WATERMARKER,
 } from './constants';
 
 /**
@@ -185,6 +187,26 @@
         field: 'latents',
      },
    },
+    {
+      source: {
+        node_id: LATENTS_TO_IMAGE,
+        field: 'image',
+      },
+      destination: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+    },
+    {
+      source: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+      destination: {
+        node_id: WATERMARKER,
+        field: 'image',
+      },
+    },
     {
       source: {
         node_id: IMAGE_TO_LATENTS,
@@ -362,7 +384,7 @@
       field: 'metadata',
     },
     destination: {
-      node_id: LATENTS_TO_IMAGE,
+      node_id: WATERMARKER,
      field: 'metadata',
    },
  });

View File

@@ -16,6 +16,8 @@ import {
   POSITIVE_CONDITIONING,
   TEXT_TO_IMAGE_GRAPH,
   TEXT_TO_LATENTS,
+  NSFW_CHECKER,
+  WATERMARKER,
 } from './constants';
 
 export const buildLinearTextToImageGraph = (
@@ -47,7 +49,7 @@
   }
 
   /**
-   * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the v
+   * The easiest way to build linear graphs is to do it in the node editor, then copy and paste the
    * full graph here as a template. Then use the parameters from app state and set friendlier node
    * ids.
    *
@@ -180,6 +182,26 @@
         field: 'noise',
       },
     },
+    {
+      source: {
+        node_id: LATENTS_TO_IMAGE,
+        field: 'image',
+      },
+      destination: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+    },
+    {
+      source: {
+        node_id: NSFW_CHECKER,
+        field: 'image',
+      },
+      destination: {
+        node_id: WATERMARKER,
+        field: 'image',
+      },
+    },
   ],
 };
@@ -210,7 +232,7 @@
       field: 'metadata',
     },
     destination: {
-      node_id: LATENTS_TO_IMAGE,
+      node_id: WATERMARKER,
      field: 'metadata',
    },
  });

View File

@@ -3,6 +3,8 @@ export const POSITIVE_CONDITIONING = 'positive_conditioning';
 export const NEGATIVE_CONDITIONING = 'negative_conditioning';
 export const TEXT_TO_LATENTS = 'text_to_latents';
 export const LATENTS_TO_IMAGE = 'latents_to_image';
+export const NSFW_CHECKER = 'nsfw_checker';
+export const WATERMARKER = 'invisible_watermark';
 export const NOISE = 'noise';
 export const RANDOM_INT = 'rand_int';
 export const RANGE_OF_SIZE = 'range_of_size';
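
All four frontend builders wire the same post-processing tail, so the graph JSON they submit now ends with l2i -> img_nsfw -> img_watermark, and the metadata edge targets the watermarker instead of l2i. For reference, the shape of that shared fragment, written as a Python dict with the ids from constants.ts (only the new pieces are shown):

# The post-processing tail shared by the four graph builders above,
# expressed as the JSON the backend receives.
post_processing_tail = {
    "nodes": {
        "nsfw_checker": {"type": "img_nsfw", "id": "nsfw_checker", "is_intermediate": True},
        "invisible_watermark": {"type": "img_watermark", "id": "invisible_watermark"},
    },
    "edges": [
        {
            "source": {"node_id": "latents_to_image", "field": "image"},
            "destination": {"node_id": "nsfw_checker", "field": "image"},
        },
        {
            "source": {"node_id": "nsfw_checker", "field": "image"},
            "destination": {"node_id": "invisible_watermark", "field": "image"},
        },
    ],
}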