Add mask to l2l, MaskEdge and ColorCorrect nodes

Sergey Borisov 2023-07-24 14:25:54 +03:00
parent 02618a701d
commit 0ebe2c0ebc
2 changed files with 201 additions and 8 deletions

View File

@@ -2,6 +2,7 @@
from typing import Literal, Optional
import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps, ImageChops
from pydantic import BaseModel, Field
@@ -193,13 +194,10 @@ class ImagePasteInvocation(BaseInvocation, PILInvocationConfig):
    def invoke(self, context: InvocationContext) -> ImageOutput:
        base_image = context.services.images.get_pil_image(self.base_image.image_name)
        image = context.services.images.get_pil_image(self.image.image_name)
-        mask = (
-            None
-            if self.mask is None
-            else ImageOps.invert(
-                context.services.images.get_pil_image(self.mask.image_name)
-            )
-        )
+        mask = None
+        if self.mask is not None:
+            mask = context.services.images.get_pil_image(self.mask.image_name)
+            mask = ImageOps.invert(mask.convert("L"))
        # TODO: probably shouldn't invert mask here... should user be required to do it?

        min_x = min(0, self.x)
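
For reference, PIL's Image.paste applies the mask directly: where the mask is 255 the pasted image shows through, where it is 0 the base is kept. The inversion above therefore makes a white user mask protect the base image. A minimal sketch with toy images (not node code):

from PIL import Image, ImageOps

base = Image.new("RGB", (8, 8), "blue")
overlay = Image.new("RGB", (8, 8), "red")
user_mask = Image.new("L", (8, 8), 255)     # white = area the user marked
inverted = ImageOps.invert(user_mask)       # white -> black
base.paste(overlay, (0, 0), mask=inverted)  # 0 in the mask hides the overlay
print(base.getpixel((0, 0)))                # (0, 0, 255): base preserved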
@@ -650,3 +648,167 @@ class ImageInverseLerpInvocation(BaseInvocation, PILInvocationConfig):
            width=image_dto.width,
            height=image_dto.height,
        )


class MaskEdgeInvocation(BaseInvocation, PILInvocationConfig):
    """Applies an edge mask to an image"""

    # fmt: off
    type: Literal["mask_edge"] = "mask_edge"

    # Inputs
    image: Optional[ImageField] = Field(default=None, description="The image to apply the mask to")
    edge_size: int = Field(description="The size of the edge")
    edge_blur: int = Field(description="The amount of blur on the edge")
    low_threshold: int = Field(description="First threshold for the hysteresis procedure in Canny edge detection")
    high_threshold: int = Field(description="Second threshold for the hysteresis procedure in Canny edge detection")
    # fmt: on

    def invoke(self, context: InvocationContext) -> MaskOutput:
        mask = context.services.images.get_pil_image(self.image.image_name)

        npimg = numpy.asarray(mask, dtype=numpy.uint8)
        npgradient = numpy.uint8(
            255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))
        )
        npedge = cv2.Canny(npimg, threshold1=self.low_threshold, threshold2=self.high_threshold)
        npmask = npgradient + npedge
        npmask = cv2.dilate(
            npmask, numpy.ones((3, 3), numpy.uint8), iterations=int(self.edge_size / 2)
        )

        new_mask = Image.fromarray(npmask)
        if self.edge_blur > 0:
            new_mask = new_mask.filter(ImageFilter.BoxBlur(self.edge_blur))

        new_mask = ImageOps.invert(new_mask)

        image_dto = context.services.images.create(
            image=new_mask,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.MASK,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return MaskOutput(
            mask=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
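
Stripped of the node plumbing, the edge-mask construction above is: mark pixels that are neither fully black nor fully white, add hard Canny edges, dilate, blur, invert. A standalone sketch of the same math (the function name and default values are illustrative, not part of the commit):

import cv2
import numpy
from PIL import Image, ImageFilter, ImageOps

def edge_mask(mask: Image.Image, edge_size: int = 8, edge_blur: int = 4,
              low: int = 100, high: int = 200) -> Image.Image:
    npimg = numpy.asarray(mask.convert("L"), dtype=numpy.uint8)
    # 255 wherever the pixel is neither exactly 0 nor exactly 255
    npgradient = numpy.uint8(
        255 * (1.0 - numpy.floor(numpy.abs(0.5 - numpy.float32(npimg) / 255.0) * 2.0))
    )
    npedge = cv2.Canny(npimg, threshold1=low, threshold2=high)
    # uint8 addition wraps, but both terms are 0 or 255, so the sum stays
    # nonzero wherever either detector fired
    npmask = cv2.dilate(
        npgradient + npedge, numpy.ones((3, 3), numpy.uint8), iterations=edge_size // 2
    )
    out = Image.fromarray(npmask)
    if edge_blur > 0:
        out = out.filter(ImageFilter.BoxBlur(edge_blur))
    return ImageOps.invert(out)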


class ColorCorrectInvocation(BaseInvocation, PILInvocationConfig):
    """Color-corrects a result image to match an initial image, optionally limited by a mask"""

    type: Literal["color_correct"] = "color_correct"

    init: Optional[ImageField] = Field(default=None, description="Initial image")
    result: Optional[ImageField] = Field(default=None, description="Resulting image")
    mask: Optional[ImageField] = Field(default=None, description="Mask image")
    mask_blur_radius: float = Field(default=8, description="Mask blur radius")
    def invoke(self, context: InvocationContext) -> ImageOutput:
        pil_init_mask = None
        if self.mask is not None:
            pil_init_mask = context.services.images.get_pil_image(
                self.mask.image_name
            ).convert("L")
        # NOTE: mask is declared Optional, but the array math below assumes one
        # is provided; invoking this node without a mask will fail.

        init_image = context.services.images.get_pil_image(
            self.init.image_name
        )

        result = context.services.images.get_pil_image(
            self.result.image_name
        ).convert("RGBA")

        # if init_image is None or init_mask is None:
        #    return result

        # Get the original alpha channel of the mask if there is one.
        # Otherwise it is some other black/white image format ('1', 'L' or 'RGB')
        # pil_init_mask = (
        #    init_mask.getchannel("A")
        #    if init_mask.mode == "RGBA"
        #    else init_mask.convert("L")
        # )

        pil_init_image = init_image.convert(
            "RGBA"
        )  # Add an alpha channel if one doesn't exist

        # Build an image with only visible pixels from source to use as reference for color-matching.
        init_rgb_pixels = numpy.asarray(init_image.convert("RGB"), dtype=numpy.uint8)
        init_a_pixels = numpy.asarray(pil_init_image.getchannel("A"), dtype=numpy.uint8)
        init_mask_pixels = numpy.asarray(pil_init_mask, dtype=numpy.uint8)

        # Get numpy version of result
        np_image = numpy.asarray(result.convert("RGB"), dtype=numpy.uint8)

        # Mask and calculate mean and standard deviation
        mask_pixels = init_a_pixels * init_mask_pixels > 0
        np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
        np_image_masked = np_image[mask_pixels, :]

        if np_init_rgb_pixels_masked.size > 0:
            init_means = np_init_rgb_pixels_masked.mean(axis=0)
            init_std = np_init_rgb_pixels_masked.std(axis=0)
            gen_means = np_image_masked.mean(axis=0)
            gen_std = np_image_masked.std(axis=0)

            # Color correct
            np_matched_result = np_image.copy()
            np_matched_result[:, :, :] = (
                (
                    (
                        (
                            np_matched_result[:, :, :].astype(numpy.float32)
                            - gen_means[None, None, :]
                        )
                        / gen_std[None, None, :]
                    )
                    * init_std[None, None, :]
                    + init_means[None, None, :]
                )
                .clip(0, 255)
                .astype(numpy.uint8)
            )
            matched_result = Image.fromarray(np_matched_result, mode="RGB")
        else:
            matched_result = Image.fromarray(np_image, mode="RGB")

        # Blur the mask out (into init image) by specified amount
        if self.mask_blur_radius > 0:
            nm = numpy.asarray(pil_init_mask, dtype=numpy.uint8)
            nmd = cv2.erode(
                nm,
                kernel=numpy.ones((3, 3), dtype=numpy.uint8),
                iterations=int(self.mask_blur_radius / 2),
            )
            pmd = Image.fromarray(nmd, mode="L")
            blurred_init_mask = pmd.filter(ImageFilter.BoxBlur(self.mask_blur_radius))
        else:
            blurred_init_mask = pil_init_mask

        multiplied_blurred_init_mask = ImageChops.multiply(
            blurred_init_mask, result.split()[-1]
        )

        # Paste original on color-corrected generation (using blurred mask)
        matched_result.paste(init_image, (0, 0), mask=multiplied_blurred_init_mask)

        image_dto = context.services.images.create(
            image=matched_result,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
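
The core of ColorCorrectInvocation is classic per-channel mean/std matching in RGB (a Reinhard-style color transfer): normalize the generated pixels by their own statistics, then rescale with the reference statistics. On bare arrays (hypothetical data, seeded for repeatability):

import numpy

rng = numpy.random.default_rng(0)
ref = rng.integers(0, 256, (64, 64, 3)).astype(numpy.float32)  # reference (init) pixels
gen = rng.integers(0, 256, (64, 64, 3)).astype(numpy.float32)  # generated pixels

# shift gen's per-channel mean/std onto ref's
matched = (
    (gen - gen.mean(axis=(0, 1))) / gen.std(axis=(0, 1))
    * ref.std(axis=(0, 1))
    + ref.mean(axis=(0, 1))
)
matched = matched.clip(0, 255).astype(numpy.uint8)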

View File

@@ -3,6 +3,8 @@
from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize
import einops
import torch
from diffusers import ControlNetModel
@@ -394,6 +396,9 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
    strength: float = Field(
        default=0.7, ge=0, le=1,
        description="The strength of the latents to use")

    mask: Optional[ImageField] = Field(
        None, description="Mask",
    )

    # Schema customisation
    class Config(InvocationConfig):

@@ -409,10 +414,25 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            },
        }
    def prep_mask_tensor(self, context, latents):
        if self.mask is None:
            return None

        mask_image = context.services.images.get_pil_image(self.mask.image_name)
        if mask_image.mode != "L":
            # FIXME: why do we get passed an RGB image here? We can only use single-channel.
            mask_image = mask_image.convert("L")
        mask_tensor = image_resized_to_grid_as_tensor(mask_image, normalize=False)
        mask_tensor = tv_resize(
            mask_tensor, latents.shape[-2:], T.InterpolationMode.BILINEAR
        )
        return mask_tensor

    @torch.no_grad()
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        noise = context.services.latents.get(self.noise.latents_name)
        latent = context.services.latents.get(self.latents.latents_name)
        mask = self.prep_mask_tensor(context, latent)

        # Get the source node id (we are invoking the prepared node)
        graph_execution_state = context.services.graph_execution_manager.get(
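
prep_mask_tensor above downscales the pixel-space mask to the latent grid (SD latents are 1/8 of the image resolution). A shape-level sketch with hypothetical tensors in place of the invocation context:

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import resize as tv_resize

mask = torch.rand(1, 512, 512)      # (channels, H, W) mask in [0, 1]
latents = torch.rand(1, 4, 64, 64)  # (batch, channels, H/8, W/8)
mask = tv_resize(mask, latents.shape[-2:], T.InterpolationMode.BILINEAR)
print(mask.shape)                   # torch.Size([1, 64, 64])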
@@ -441,6 +461,7 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
        noise = noise.to(device=unet.device, dtype=unet.dtype)
        latent = latent.to(device=unet.device, dtype=unet.dtype)
        if mask is not None:
            mask = mask.to(device=unet.device, dtype=unet.dtype)

        scheduler = get_scheduler(
            context=context,
@@ -470,6 +491,15 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            device=unet.device,
        )

        def _apply_mask_on_step(step_output, timestep, conditioning_data):
            noised_init = scheduler.add_noise(initial_latents, noise, timestep.unsqueeze(0))
            step_output.prev_sample = step_output.prev_sample * (1 - mask) + noised_init * mask
            return step_output

        additional_guidance = []
        if mask is not None:
            additional_guidance.append(_apply_mask_on_step)
        result_latents, result_attention_map_saver = pipeline.latents_from_embeddings(
            latents=initial_latents,
            timesteps=timesteps,
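
The _apply_mask_on_step guidance is RePaint-style blending: after each denoising step, pixels where the mask is 1 are snapped back to the init latents noised to the current timestep, so only the mask-0 region is actually regenerated. A sketch of one step with hypothetical tensors (diffusers' DDIMScheduler standing in for the pipeline's scheduler):

import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
scheduler.set_timesteps(20)

initial_latents = torch.randn(1, 4, 64, 64)
noise = torch.randn_like(initial_latents)
mask = torch.zeros(1, 1, 64, 64)  # 1 = keep init content, 0 = regenerate
mask[..., 16:48, 16:48] = 1.0

prev_sample = torch.randn_like(initial_latents)  # stand-in for one step's output
t = scheduler.timesteps[5]
noised_init = scheduler.add_noise(initial_latents, noise, t.unsqueeze(0))
prev_sample = prev_sample * (1 - mask) + noised_init * mask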
@@ -477,7 +507,8 @@ class LatentsToLatentsInvocation(TextToLatentsInvocation):
            num_inference_steps=self.steps,
            conditioning_data=conditioning_data,
            control_data=control_data,  # list[ControlNetData]
-            callback=step_callback
+            callback=step_callback,
+            additional_guidance=additional_guidance,
        )

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699